Spaces
johnsonhung committed · Commit 2a3a041
1 Parent(s): 2d2ef3c
init
Files changed:
- .gitignore +112 -0
- CODE_OF_CONDUCT.md +5 -0
- CONTRIBUTING.md +36 -0
- LICENSE.md +21 -0
- README.md +119 -12
- app.py +121 -0
- data/README.md +1 -0
- data/demo_imgs/1.jpg +0 -0
- data/demo_imgs/2.jpg +0 -0
- data/demo_imgs/3.jpg +0 -0
- data/demo_imgs/4.jpg +0 -0
- data/demo_imgs/5.jpg +0 -0
- data/demo_imgs/6.jpg +0 -0
- requirements.txt +11 -0
- src/args.py +168 -0
- src/build_vocab.py +409 -0
- src/data_loader.py +193 -0
- src/demo.ipynb +271 -0
- src/demo.py +133 -0
- src/model.py +236 -0
- src/model1_inf.py +43 -0
- src/modules/encoder.py +57 -0
- src/modules/multihead_attention.py +203 -0
- src/modules/transformer_decoder.py +502 -0
- src/modules/utils.py +387 -0
- src/read_pkl.py +7 -0
- src/sample.py +207 -0
- src/sim_ingr.py +197 -0
- src/train.py +398 -0
- src/utils/ims2file.py +94 -0
- src/utils/metrics.py +78 -0
- src/utils/output_ing.py +28 -0
- src/utils/output_utils.py +103 -0
- src/utils/tb_visualizer.py +66 -0
.gitignore
ADDED
@@ -0,0 +1,112 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
*.pkl
*.png
*.json
*.nfs*
*.tex
.idea/
*.bin
*.ckpt
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,5 @@
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.
CONTRIBUTING.md
ADDED
@@ -0,0 +1,36 @@
# Contributing
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## Coding Style
* 4 spaces for indentation rather than tabs
* 100 character line length
* PEP8 formatting

## License
By contributing to this project, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
LICENSE.md
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Facebook, Inc. and its affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,119 @@
## Inverse Cooking: Recipe Generation from Food Images

Code supporting the paper:

*Amaia Salvador, Michal Drozdzal, Xavier Giro-i-Nieto, Adriana Romero.
[Inverse Cooking: Recipe Generation from Food Images](https://arxiv.org/abs/1812.06164).
CVPR 2019*

If you find this code useful in your research, please consider citing it using the
following BibTeX entry:

```
@InProceedings{Salvador2019inversecooking,
author = {Salvador, Amaia and Drozdzal, Michal and Giro-i-Nieto, Xavier and Romero, Adriana},
title = {Inverse Cooking: Recipe Generation From Food Images},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2019}
}
```

### Installation

This code uses Python 3.6, PyTorch 0.4.1, and CUDA 9.0.

- Install PyTorch:
```bash
$ conda install pytorch=0.4.1 cuda90 -c pytorch
```

- Install dependencies:
```bash
$ pip install -r requirements.txt
```

### Pretrained model

- Download the ingredient and instruction vocabularies [here](https://dl.fbaipublicfiles.com/inversecooking/ingr_vocab.pkl) and [here](https://dl.fbaipublicfiles.com/inversecooking/instr_vocab.pkl), respectively.
- Download the pretrained model [here](https://dl.fbaipublicfiles.com/inversecooking/modelbest.ckpt).

### Demo

You can use our pretrained model to get recipes for your images.

Download the required files (listed above), place them under the ```data``` directory, and try our demo notebook ```src/demo.ipynb```.

Note: the demo will run on GPU if a device is found, otherwise it will use CPU.

### Data

- Download [Recipe1M](http://im2recipe.csail.mit.edu/dataset/download) (registration required).
- Extract the files somewhere (we refer to this path as ```path_to_dataset```).
- The contents of ```path_to_dataset``` should be the following:
```
det_ingrs.json
layer1.json
layer2.json
images/
images/train
images/val
images/test
```

*Note: all python calls below must be run from ```./src```.*

### Build vocabularies

```bash
$ python build_vocab.py --recipe1m_path path_to_dataset
```

### Images to LMDB (Optional, but recommended)

For fast loading during training:

```bash
$ python utils/ims2file.py --recipe1m_path path_to_dataset
```

If you decide not to create this file, use the flag ```--load_jpeg``` when training the model.

### Training

Create a directory to store checkpoints for all the models you train (e.g. ```../checkpoints```) and point ```--save_dir``` to it.

We train our model in two stages:

1. Ingredient prediction from images

```bash
python train.py --model_name im2ingr --batch_size 150 --finetune_after 0 --ingrs_only \
--es_metric iou_sample --loss_weight 0 1000.0 1.0 1.0 \
--learning_rate 1e-4 --scale_learning_rate_cnn 1.0 \
--save_dir ../checkpoints --recipe1m_dir path_to_dataset
```

2. Recipe generation from images and ingredients (loading from 1.)

```bash
python train.py --model_name model --batch_size 256 --recipe_only --transfer_from im2ingr \
--save_dir ../checkpoints --recipe1m_dir path_to_dataset
```

Check training progress with Tensorboard from ```../checkpoints```:

```bash
$ tensorboard --logdir='../tb_logs' --port=6006
```

### Evaluation

- Save generated recipes to disk with
```python sample.py --model_name model --save_dir ../checkpoints --recipe1m_dir path_to_dataset --greedy --eval_split test```.
- This script will also report ingredient metrics (F1 and IoU).

### License

inversecooking is released under the MIT license; see [LICENSE](LICENSE.md) for details.
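For quick reference, the following sketch condenses the model-loading steps that the demo notebook added in this commit (```src/demo.ipynb```) performs. It assumes the vocabularies and checkpoint listed above were placed under ```../data``` and that it is run from ```./src```, as the README requires; it is illustrative, not part of the repository.

```python
# Sketch: load the pretrained inverse-cooking model the same way src/demo.ipynb does.
# Assumes ingr_vocab.pkl, instr_vocab.pkl and modelbest.ckpt live in ../data.
import os
import pickle
import sys

import torch

from args import get_parser
from model import get_model

sys.argv = ['']   # get_parser() calls parse_args(); clear any notebook/CLI arguments first

data_dir = '../data'
ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'ingr_vocab.pkl'), 'rb'))
instr_vocab = pickle.load(open(os.path.join(data_dir, 'instr_vocab.pkl'), 'rb'))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
map_loc = None if torch.cuda.is_available() else 'cpu'

args = get_parser()
args.maxseqlen = 15        # same overrides as in the demo notebook
args.ingrs_only = False

model = get_model(args, len(ingrs_vocab), len(instr_vocab))
model.load_state_dict(torch.load(os.path.join(data_dir, 'modelbest.ckpt'), map_location=map_loc))
model.to(device)
model.eval()
```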
app.py
ADDED
@@ -0,0 +1,121 @@
from PIL import Image
import requests
import pickle
from io import BytesIO
import gradio as gr
from src.args import get_parser
from src.model import get_model
import torch
import os
from src.model1_inf import im2ingr
import numpy as np

response = requests.get("https://i.imgur.com/DwR24EM.jpeg")
dog_img = Image.open(BytesIO(response.content))


def img2ingr(image):
    # img_file = '../data/demo_imgs/1.jpg'
    # image = Image.open(img_file).convert('RGB')
    img = Image.fromarray(np.uint8(image)).convert('RGB')
    ingr = im2ingr(img, ingrs_vocab, model)
    return ' '.join(ingr)


def img_ingr2recipe(image, ingr):
    # placeholder output until the recipe-generation model ("model2" below) is wired in
    print(image.shape, ingr)
    return dog_img, "A delicious meme dog \n--------\n1. Cook it!\n2. GL&HF"


def change_checkbox(predicted_ingr):
    return gr.update(label="Ingredient required", interactive=True,
                     choices=predicted_ingr.split(), value=predicted_ingr.split())


def add_ingr(new_ingr):
    print(new_ingr)
    return "hello"


def add_to_checkbox(old_ingr, new_ingr):
    # check if in dict or not
    return gr.update(label="Ingredient required", interactive=True,
                     choices=[*old_ingr, new_ingr], value=[*old_ingr, new_ingr])


""" load model1 """
args = get_parser()

# basic parameters
model_dir = './data'
data_dir = './data'
example_dir = './data/demo_imgs/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
map_loc = None if torch.cuda.is_available() else 'cpu'

# load ingredients vocab
ingrs_vocab = pickle.load(open(os.path.join(model_dir, 'ingr_vocab.pkl'), 'rb'))
vocab = pickle.load(open(os.path.join(data_dir, 'instr_vocab.pkl'), 'rb'))

ingr_vocab_size = len(ingrs_vocab)
instrs_vocab_size = len(vocab)

# model setting and loading
args.maxseqlen = 15
args.ingrs_only = True
model = get_model(args, ingr_vocab_size, instrs_vocab_size)
model_path = os.path.join(model_dir, 'modelbest.ckpt')
model.load_state_dict(torch.load(model_path, map_location=map_loc))
model.to(device)
model.eval()
model.ingrs_only = True
model.recipe_only = False

""" load model2 """


""" gradio """
# input image -> list all required ingrs -> checkbox for selecting ingrs / input_box for input more ingrs user want -> output: recipe and its image
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Recipedia
    Start finding the yummy recipe ...
    """)
    with gr.Tabs():
        with gr.TabItem("User"):
            # input image
            # NOTE: type='filepath' passes a path string to the callbacks, while
            # img2ingr/img_ingr2recipe above treat the input as an image array
            image_input = gr.Image(label="Upload the image of your yummy food", type='filepath')
            gr.Examples(examples=[example_dir+"1.jpg", example_dir+"2.jpg", example_dir+"3.jpg",
                                  example_dir+"4.jpg", example_dir+"5.jpg", example_dir+"6.jpg"],
                        inputs=image_input)
            with gr.Row():
                # clear_img_btn = gr.Button("Clear")
                image_btn = gr.Button("Upload", variant="primary")
            # list all required ingrs -> checkbox for selecting ingrs / input_box for input more ingrs user want
            predicted_ingr = gr.Textbox(visible=False)

            with gr.Row():
                checkboxes = gr.CheckboxGroup(label="Ingredient required", interactive=True)
                new_ingr = gr.Textbox(label="Additional ingredients", max_lines=1)
            # with gr.Row():
            #     new_btn_clear = gr.Button("Clear")
            #     new_btn = gr.Button("Add", variant="primary")

            add_ingr = gr.Textbox(visible=False)

            with gr.Row():
                clear_ingr_btn = gr.Button("Reset")
                ingr_btn = gr.Button("Confirm", variant="primary")

            # output: recipe and its image
            with gr.Row():
                out_recipe = gr.Textbox(label="Your recipe", value="Spaghetti ---\n1. cook it!")
                out_image = gr.Image(label="Looks yummy ><")

        with gr.TabItem("Example"):
            image_button = gr.Button("Flip")

    image_btn.click(img2ingr, inputs=image_input, outputs=predicted_ingr)
    predicted_ingr.change(fn=change_checkbox, inputs=predicted_ingr, outputs=checkboxes)

    # new_btn.click(img2ingr, inputs=new_ingr, outputs=predicted_ingr)
    new_ingr.submit(fn=add_to_checkbox, inputs=[checkboxes, new_ingr], outputs=checkboxes)

    ingr_btn.click(img_ingr2recipe, inputs=[image_input, checkboxes], outputs=[out_image, out_recipe])


demo.launch(debug=True, share=True)
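The event wiring above chains three steps: the Upload button writes the predicted ingredient string into the hidden ```predicted_ingr``` textbox, its change event rebuilds the checkbox group, and the Confirm button sends the image plus the selected ingredients to ```img_ingr2recipe``` (still a placeholder, since the "load model2" section is empty). As a rough illustration, the callbacks can be exercised without launching the UI; this sketch assumes importing ```app``` succeeds, i.e. its module-level vocabulary and checkpoint loading finds the files under ```./data```.

```python
# Sketch: call app.py's Gradio callbacks directly, without launching the interface.
import numpy as np

import app  # executes the module-level model/vocab loading in app.py

dummy = np.zeros((256, 256, 3), dtype=np.uint8)   # stand-in RGB image array
ingredients = app.img2ingr(dummy)                 # space-separated ingredient names
print(ingredients)

# change_checkbox turns that string into a gr.update(...) payload for the CheckboxGroup
print(app.change_checkbox(ingredients))
```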
data/README.md
ADDED
@@ -0,0 +1 @@
# Vocabulary file will be saved here
data/demo_imgs/1.jpg
ADDED
data/demo_imgs/2.jpg
ADDED
data/demo_imgs/3.jpg
ADDED
data/demo_imgs/4.jpg
ADDED
data/demo_imgs/5.jpg
ADDED
data/demo_imgs/6.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,11 @@
numpy
scipy
matplotlib
# torch==0.4.1
# torchvision==0.2.1
nltk
Pillow
tqdm
lmdb
tensorflow
tensorboardX
src/args.py
ADDED
@@ -0,0 +1,168 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import argparse
import os


def get_parser():

    parser = argparse.ArgumentParser()

    parser.add_argument('--save_dir', type=str, default='path/to/save/models',
                        help='path where checkpoints will be saved')

    parser.add_argument('--project_name', type=str, default='inversecooking',
                        help='name of the directory where models will be saved within save_dir')

    parser.add_argument('--model_name', type=str, default='model',
                        help='save_dir/project_name/model_name will be the path where logs and checkpoints are stored')

    parser.add_argument('--transfer_from', type=str, default='',
                        help='specify model name to transfer from')

    parser.add_argument('--suff', type=str, default='',
                        help='the id of the dictionary to load for training')

    parser.add_argument('--image_model', type=str, default='resnet50',
                        choices=['resnet18', 'resnet50', 'resnet101', 'resnet152', 'inception_v3'])

    parser.add_argument('--recipe1m_dir', type=str, default='path/to/recipe1m',
                        help='directory where recipe1m dataset is extracted')

    parser.add_argument('--aux_data_dir', type=str, default='../data',
                        help='path to other necessary data files (eg. vocabularies)')

    parser.add_argument('--crop_size', type=int, default=224, help='size for randomly or center cropping images')

    parser.add_argument('--image_size', type=int, default=256, help='size to rescale images')

    parser.add_argument('--log_step', type=int, default=10, help='step size for printing log info')

    parser.add_argument('--learning_rate', type=float, default=0.001,
                        help='base learning rate')

    parser.add_argument('--scale_learning_rate_cnn', type=float, default=0.01,
                        help='lr multiplier for cnn weights')

    parser.add_argument('--lr_decay_rate', type=float, default=0.99,
                        help='learning rate decay factor')

    parser.add_argument('--lr_decay_every', type=int, default=1,
                        help='frequency of learning rate decay (default is every epoch)')

    parser.add_argument('--weight_decay', type=float, default=0.)

    parser.add_argument('--embed_size', type=int, default=512,
                        help='hidden size for all projections')

    parser.add_argument('--n_att', type=int, default=8,
                        help='number of attention heads in the instruction decoder')

    parser.add_argument('--n_att_ingrs', type=int, default=4,
                        help='number of attention heads in the ingredient decoder')

    parser.add_argument('--transf_layers', type=int, default=16,
                        help='number of transformer layers in the instruction decoder')

    parser.add_argument('--transf_layers_ingrs', type=int, default=4,
                        help='number of transformer layers in the ingredient decoder')

    parser.add_argument('--num_epochs', type=int, default=400,
                        help='maximum number of epochs')

    parser.add_argument('--batch_size', type=int, default=128)

    parser.add_argument('--num_workers', type=int, default=8)

    parser.add_argument('--dropout_encoder', type=float, default=0.3,
                        help='dropout ratio for the image and ingredient encoders')

    parser.add_argument('--dropout_decoder_r', type=float, default=0.3,
                        help='dropout ratio in the instruction decoder')

    parser.add_argument('--dropout_decoder_i', type=float, default=0.3,
                        help='dropout ratio in the ingredient decoder')

    parser.add_argument('--finetune_after', type=int, default=-1,
                        help='epoch to start training cnn. -1 is never, 0 is from the beginning')

    parser.add_argument('--loss_weight', nargs='+', type=float, default=[1.0, 0.0, 0.0, 0.0],
                        help='training loss weights. 1) instruction, 2) ingredient, 3) eos 4) cardinality')

    parser.add_argument('--max_eval', type=int, default=4096,
                        help='number of validation samples to evaluate during training')

    parser.add_argument('--label_smoothing_ingr', type=float, default=0.1,
                        help='label smoothing for bce loss for ingredients')

    parser.add_argument('--patience', type=int, default=50,
                        help='maximum number of epochs to allow before early stopping')

    parser.add_argument('--maxseqlen', type=int, default=15,
                        help='maximum length of each instruction')

    parser.add_argument('--maxnuminstrs', type=int, default=10,
                        help='maximum number of instructions')

    parser.add_argument('--maxnumims', type=int, default=5,
                        help='maximum number of images per sample')

    parser.add_argument('--maxnumlabels', type=int, default=20,
                        help='maximum number of ingredients per sample')

    parser.add_argument('--es_metric', type=str, default='loss', choices=['loss', 'iou_sample'],
                        help='early stopping metric to track')

    parser.add_argument('--eval_split', type=str, default='val')

    parser.add_argument('--numgens', type=int, default=3)

    parser.add_argument('--greedy', dest='greedy', action='store_true',
                        help='enables greedy sampling (inference only)')
    parser.set_defaults(greedy=False)

    parser.add_argument('--temperature', type=float, default=1.0,
                        help='sampling temperature (when greedy is False)')

    parser.add_argument('--beam', type=int, default=-1,
                        help='beam size. -1 means no beam search (either greedy or sampling)')

    parser.add_argument('--ingrs_only', dest='ingrs_only', action='store_true',
                        help='train or evaluate the model only for ingredient prediction')
    parser.set_defaults(ingrs_only=False)

    parser.add_argument('--recipe_only', dest='recipe_only', action='store_true',
                        help='train or evaluate the model only for instruction generation')
    parser.set_defaults(recipe_only=False)

    parser.add_argument('--log_term', dest='log_term', action='store_true',
                        help='if used, shows training log in stdout instead of saving it to a file.')
    parser.set_defaults(log_term=False)

    parser.add_argument('--notensorboard', dest='tensorboard', action='store_false',
                        help='if used, tensorboard logs will not be saved')
    parser.set_defaults(tensorboard=True)

    parser.add_argument('--resume', dest='resume', action='store_true',
                        help='resume training from the checkpoint in model_name')
    parser.set_defaults(resume=False)

    parser.add_argument('--nodecay_lr', dest='decay_lr', action='store_false',
                        help='disables learning rate decay')
    parser.set_defaults(decay_lr=True)

    parser.add_argument('--load_jpeg', dest='use_lmdb', action='store_false',
                        help='if used, images are loaded from jpg files instead of lmdb')
    parser.set_defaults(use_lmdb=True)

    parser.add_argument('--get_perplexity', dest='get_perplexity', action='store_true',
                        help='used to get perplexity in evaluation')
    parser.set_defaults(get_perplexity=False)

    parser.add_argument('--use_true_ingrs', dest='use_true_ingrs', action='store_true',
                        help='if used, true ingredients will be used as input to obtain the recipe in evaluation')
    parser.set_defaults(use_true_ingrs=False)

    args = parser.parse_args()

    return args
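Because ```get_parser()``` calls ```parse_args()``` directly, the notebook and app code in this commit obtain a default namespace by clearing ```sys.argv``` first and then overriding individual fields. A minimal sketch of that pattern (assumed to run from ```./src```):

```python
# Sketch: build a default args namespace programmatically, then override fields,
# mirroring the pattern used in src/demo.ipynb and app.py.
import sys
sys.argv = ['']            # parse_args() would otherwise consume the caller's CLI arguments

from args import get_parser

args = get_parser()
args.maxseqlen = 15        # maximum words per generated instruction
args.ingrs_only = False    # run both the ingredient and instruction decoders
print(args.batch_size, args.image_model)
```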
src/build_vocab.py
ADDED
@@ -0,0 +1,409 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import nltk
import pickle
import argparse
from collections import Counter
import json
import os
from tqdm import *
import numpy as np
import re


class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word, idx=None):
        if idx is None:
            if not word in self.word2idx:
                self.word2idx[word] = self.idx
                self.idx2word[self.idx] = word
                self.idx += 1
            return self.idx
        else:
            if not word in self.word2idx:
                self.word2idx[word] = idx
                if idx in self.idx2word.keys():
                    self.idx2word[idx].append(word)
                else:
                    self.idx2word[idx] = [word]

            return idx

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<pad>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


def get_ingredient(det_ingr, replace_dict):
    det_ingr_undrs = det_ingr['text'].lower()
    det_ingr_undrs = ''.join(i for i in det_ingr_undrs if not i.isdigit())

    for rep, char_list in replace_dict.items():
        for c_ in char_list:
            if c_ in det_ingr_undrs:
                det_ingr_undrs = det_ingr_undrs.replace(c_, rep)
    det_ingr_undrs = det_ingr_undrs.strip()
    det_ingr_undrs = det_ingr_undrs.replace(' ', '_')

    return det_ingr_undrs


def get_instruction(instruction, replace_dict, instruction_mode=True):
    instruction = instruction.lower()

    for rep, char_list in replace_dict.items():
        for c_ in char_list:
            if c_ in instruction:
                instruction = instruction.replace(c_, rep)
    instruction = instruction.strip()
    # remove sentences starting with "1.", "2.", ... from the targets
    if len(instruction) > 0 and instruction[0].isdigit() and instruction_mode:
        instruction = ''
    return instruction


def remove_plurals(counter_ingrs, ingr_clusters):
    del_ingrs = []

    for k, v in counter_ingrs.items():

        if len(k) == 0:
            del_ingrs.append(k)
            continue

        gotit = 0
        if k[-2:] == 'es':
            if k[:-2] in counter_ingrs.keys():
                counter_ingrs[k[:-2]] += v
                ingr_clusters[k[:-2]].extend(ingr_clusters[k])
                del_ingrs.append(k)
                gotit = 1

        if k[-1] == 's' and gotit == 0:
            if k[:-1] in counter_ingrs.keys():
                counter_ingrs[k[:-1]] += v
                ingr_clusters[k[:-1]].extend(ingr_clusters[k])
                del_ingrs.append(k)
    for item in del_ingrs:
        del counter_ingrs[item]
        del ingr_clusters[item]
    return counter_ingrs, ingr_clusters


def cluster_ingredients(counter_ingrs):
    mydict = dict()
    mydict_ingrs = dict()

    for k, v in counter_ingrs.items():

        w1 = k.split('_')[-1]
        w2 = k.split('_')[0]
        lw = [w1, w2]
        if len(k.split('_')) > 1:
            w3 = k.split('_')[0] + '_' + k.split('_')[1]
            w4 = k.split('_')[-2] + '_' + k.split('_')[-1]

            lw = [w1, w2, w4, w3]

        gotit = 0
        for w in lw:
            if w in counter_ingrs.keys():
                # check if its parts are
                parts = w.split('_')
                if len(parts) > 0:
                    if parts[0] in counter_ingrs.keys():
                        w = parts[0]
                    elif parts[1] in counter_ingrs.keys():
                        w = parts[1]
                if w in mydict.keys():
                    mydict[w] += v
                    mydict_ingrs[w].append(k)
                else:
                    mydict[w] = v
                    mydict_ingrs[w] = [k]
                gotit = 1
                break
        if gotit == 0:
            mydict[k] = v
            mydict_ingrs[k] = [k]

    return mydict, mydict_ingrs


def update_counter(list_, counter_toks, istrain=False):
    for sentence in list_:
        tokens = nltk.tokenize.word_tokenize(sentence)
        if istrain:
            counter_toks.update(tokens)


def build_vocab_recipe1m(args):
    print ("Loading data...")
    dets = json.load(open(os.path.join(args.recipe1m_path, 'det_ingrs.json'), 'r'))
    layer1 = json.load(open(os.path.join(args.recipe1m_path, 'layer1.json'), 'r'))
    layer2 = json.load(open(os.path.join(args.recipe1m_path, 'layer2.json'), 'r'))

    id2im = {}

    for i, entry in enumerate(layer2):
        id2im[entry['id']] = i

    print("Loaded data.")
    print("Found %d recipes in the dataset." % (len(layer1)))
    replace_dict_ingrs = {'and': ['&', "'n"], '': ['%', ',', '.', '#', '[', ']', '!', '?']}
    replace_dict_instrs = {'and': ['&', "'n"], '': ['#', '[', ']']}

    idx2ind = {}
    for i, entry in enumerate(dets):
        idx2ind[entry['id']] = i

    ingrs_file = args.save_path + 'allingrs_count.pkl'
    instrs_file = args.save_path + 'allwords_count.pkl'

    #####
    # 1. Count words in dataset and clean
    #####
    if os.path.exists(ingrs_file) and os.path.exists(instrs_file) and not args.forcegen:
        print ("loading pre-extracted word counters")
        counter_ingrs = pickle.load(open(args.save_path + 'allingrs_count.pkl', 'rb'))
        counter_toks = pickle.load(open(args.save_path + 'allwords_count.pkl', 'rb'))
    else:
        counter_toks = Counter()
        counter_ingrs = Counter()
        counter_ingrs_raw = Counter()

        for i, entry in tqdm(enumerate(layer1)):

            # get all instructions for this recipe
            instrs = entry['instructions']

            instrs_list = []
            ingrs_list = []

            # retrieve pre-detected ingredients for this entry
            det_ingrs = dets[idx2ind[entry['id']]]['ingredients']

            valid = dets[idx2ind[entry['id']]]['valid']
            det_ingrs_filtered = []

            for j, det_ingr in enumerate(det_ingrs):
                if len(det_ingr) > 0 and valid[j]:
                    det_ingr_undrs = get_ingredient(det_ingr, replace_dict_ingrs)
                    det_ingrs_filtered.append(det_ingr_undrs)
                    ingrs_list.append(det_ingr_undrs)

            # get raw text for instructions of this entry
            acc_len = 0
            for instr in instrs:
                instr = instr['text']
                instr = get_instruction(instr, replace_dict_instrs)
                if len(instr) > 0:
                    instrs_list.append(instr)
                    acc_len += len(instr)

            # discard recipes with too few or too many ingredients or instruction words
            if len(ingrs_list) < args.minnumingrs or len(instrs_list) < args.minnuminstrs \
                    or len(instrs_list) >= args.maxnuminstrs or len(ingrs_list) >= args.maxnumingrs \
                    or acc_len < args.minnumwords:
                continue

            # tokenize sentences and update counter
            update_counter(instrs_list, counter_toks, istrain=entry['partition'] == 'train')
            title = nltk.tokenize.word_tokenize(entry['title'].lower())
            if entry['partition'] == 'train':
                counter_toks.update(title)
            if entry['partition'] == 'train':
                counter_ingrs.update(ingrs_list)

        pickle.dump(counter_ingrs, open(args.save_path + 'allingrs_count.pkl', 'wb'))
        pickle.dump(counter_toks, open(args.save_path + 'allwords_count.pkl', 'wb'))
        pickle.dump(counter_ingrs_raw, open(args.save_path + 'allingrs_raw_count.pkl', 'wb'))

    # manually add missing entries for better clustering
    base_words = ['peppers', 'tomato', 'spinach_leaves', 'turkey_breast', 'lettuce_leaf',
                  'chicken_thighs', 'milk_powder', 'bread_crumbs', 'onion_flakes',
                  'red_pepper', 'pepper_flakes', 'juice_concentrate', 'cracker_crumbs', 'hot_chili',
                  'seasoning_mix', 'dill_weed', 'pepper_sauce', 'sprouts', 'cooking_spray', 'cheese_blend',
                  'basil_leaves', 'pineapple_chunks', 'marshmallow', 'chile_powder',
                  'cheese_blend', 'corn_kernels', 'tomato_sauce', 'chickens', 'cracker_crust',
                  'lemonade_concentrate', 'red_chili', 'mushroom_caps', 'mushroom_cap', 'breaded_chicken',
                  'frozen_pineapple', 'pineapple_chunks', 'seasoning_mix', 'seaweed', 'onion_flakes',
                  'bouillon_granules', 'lettuce_leaf', 'stuffing_mix', 'parsley_flakes', 'chicken_breast',
                  'basil_leaves', 'baguettes', 'green_tea', 'peanut_butter', 'green_onion', 'fresh_cilantro',
                  'breaded_chicken', 'hot_pepper', 'dried_lavender', 'white_chocolate',
                  'dill_weed', 'cake_mix', 'cheese_spread', 'turkey_breast', 'chucken_thighs', 'basil_leaves',
                  'mandarin_orange', 'laurel', 'cabbage_head', 'pistachio', 'cheese_dip',
                  'thyme_leave', 'boneless_pork', 'red_pepper', 'onion_dip', 'skinless_chicken', 'dark_chocolate',
                  'canned_corn', 'muffin', 'cracker_crust', 'bread_crumbs', 'frozen_broccoli',
                  'philadelphia', 'cracker_crust', 'chicken_breast']

    for base_word in base_words:

        if base_word not in counter_ingrs.keys():
            counter_ingrs[base_word] = 1

    counter_ingrs, cluster_ingrs = cluster_ingredients(counter_ingrs)
    counter_ingrs, cluster_ingrs = remove_plurals(counter_ingrs, cluster_ingrs)

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter_toks.items() if cnt >= args.threshold_words]
    ingrs = {word: cnt for word, cnt in counter_ingrs.items() if cnt >= args.threshold_ingrs}

    # Recipe vocab
    # Create a vocab wrapper and add some special tokens.
    vocab_toks = Vocabulary()
    vocab_toks.add_word('<start>')
    vocab_toks.add_word('<end>')
    vocab_toks.add_word('<eoi>')

    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab_toks.add_word(word)
    vocab_toks.add_word('<pad>')

    # Ingredient vocab
    # Create a vocab wrapper for ingredients
    vocab_ingrs = Vocabulary()
    idx = vocab_ingrs.add_word('<end>')
    # this returns the next idx to add words to
    # Add the ingredients to the vocabulary.
    for k, _ in ingrs.items():
        for ingr in cluster_ingrs[k]:
            idx = vocab_ingrs.add_word(ingr, idx)
        idx += 1
    _ = vocab_ingrs.add_word('<pad>', idx)

    print("Total ingr vocabulary size: {}".format(len(vocab_ingrs)))
    print("Total token vocabulary size: {}".format(len(vocab_toks)))

    dataset = {'train': [], 'val': [], 'test': []}

    ######
    # 2. Tokenize and build dataset based on vocabularies.
    ######
    for i, entry in tqdm(enumerate(layer1)):

        # get all instructions for this recipe
        instrs = entry['instructions']

        instrs_list = []
        ingrs_list = []
        images_list = []

        # retrieve pre-detected ingredients for this entry
        det_ingrs = dets[idx2ind[entry['id']]]['ingredients']
        valid = dets[idx2ind[entry['id']]]['valid']
        labels = []

        for j, det_ingr in enumerate(det_ingrs):
            if len(det_ingr) > 0 and valid[j]:
                det_ingr_undrs = get_ingredient(det_ingr, replace_dict_ingrs)
                ingrs_list.append(det_ingr_undrs)
                label_idx = vocab_ingrs(det_ingr_undrs)
                if label_idx is not vocab_ingrs('<pad>') and label_idx not in labels:
                    labels.append(label_idx)

        # get raw text for instructions of this entry
        acc_len = 0
        for instr in instrs:
            instr = instr['text']
            instr = get_instruction(instr, replace_dict_instrs)
            if len(instr) > 0:
                acc_len += len(instr)
                instrs_list.append(instr)

        # we discard recipes with too many or too few ingredients or instruction words
        if len(labels) < args.minnumingrs or len(instrs_list) < args.minnuminstrs \
                or len(instrs_list) >= args.maxnuminstrs or len(labels) >= args.maxnumingrs \
                or acc_len < args.minnumwords:
            continue

        if entry['id'] in id2im.keys():
            ims = layer2[id2im[entry['id']]]

            # copy image paths for this recipe
            for im in ims['images']:
                images_list.append(im['id'])

        # tokenize sentences
        toks = []

        for instr in instrs_list:
            tokens = nltk.tokenize.word_tokenize(instr)
            toks.append(tokens)

        title = nltk.tokenize.word_tokenize(entry['title'].lower())

        newentry = {'id': entry['id'], 'instructions': instrs_list, 'tokenized': toks,
                    'ingredients': ingrs_list, 'images': images_list, 'title': title}
        dataset[entry['partition']].append(newentry)

    print('Dataset size:')
    for split in dataset.keys():
        print(split, ':', len(dataset[split]))

    return vocab_ingrs, vocab_toks, dataset


def main(args):

    vocab_ingrs, vocab_toks, dataset = build_vocab_recipe1m(args)

    with open(os.path.join(args.save_path, args.suff+'recipe1m_vocab_ingrs.pkl'), 'wb') as f:
        pickle.dump(vocab_ingrs, f)
    with open(os.path.join(args.save_path, args.suff+'recipe1m_vocab_toks.pkl'), 'wb') as f:
        pickle.dump(vocab_toks, f)

    for split in dataset.keys():
        with open(os.path.join(args.save_path, args.suff+'recipe1m_' + split + '.pkl'), 'wb') as f:
            pickle.dump(dataset[split], f)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--recipe1m_path', type=str,
                        default='path/to/recipe1m',
                        help='recipe1m path')

    parser.add_argument('--save_path', type=str, default='../data/',
                        help='path for saving vocabulary wrapper')

    parser.add_argument('--suff', type=str, default='')

    parser.add_argument('--threshold_ingrs', type=int, default=10,
                        help='minimum ingr count threshold')

    parser.add_argument('--threshold_words', type=int, default=10,
                        help='minimum word count threshold')

    parser.add_argument('--maxnuminstrs', type=int, default=20,
                        help='max number of instructions (sentences)')

    parser.add_argument('--maxnumingrs', type=int, default=20,
                        help='max number of ingredients')

    parser.add_argument('--minnuminstrs', type=int, default=2,
                        help='min number of instructions (sentences)')

    parser.add_argument('--minnumingrs', type=int, default=2,
                        help='min number of ingredients')

    parser.add_argument('--minnumwords', type=int, default=20,
                        help='minimum number of characters in recipe')

    parser.add_argument('--forcegen', dest='forcegen', action='store_true')
    parser.set_defaults(forcegen=False)

    args = parser.parse_args()
    main(args)
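The ```Vocabulary``` wrapper above maps words to integer ids, returns the ```<pad>``` id for out-of-vocabulary words, and (for the ingredient vocabulary) clusters several surface forms under one id via the ```idx``` argument of ```add_word```. A toy sketch of that behaviour, independent of Recipe1M:

```python
# Sketch: how a built Vocabulary behaves (toy data, not Recipe1M).
from build_vocab import Vocabulary

vocab = Vocabulary()
vocab.add_word('<end>')            # id 0; returns the next free id
vocab.add_word('tomato', idx=1)    # cluster several surface forms under id 1
vocab.add_word('tomatoes', idx=1)
vocab.add_word('<pad>', idx=2)

print(vocab('tomatoes'))           # -> 1
print(vocab('unseen_word'))        # -> id of '<pad>', i.e. 2
print(len(vocab))                  # -> 3 (number of ids, not of words)
```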
src/data_loader.py
ADDED
@@ -0,0 +1,193 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import pickle
import numpy as np
import nltk
from PIL import Image
from build_vocab import Vocabulary
import random
import json
import lmdb


class Recipe1MDataset(data.Dataset):

    def __init__(self, data_dir, aux_data_dir, split, maxseqlen, maxnuminstrs, maxnumlabels, maxnumims,
                 transform=None, max_num_samples=-1, use_lmdb=False, suff=''):

        self.ingrs_vocab = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_vocab_ingrs.pkl'), 'rb'))
        self.instrs_vocab = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_vocab_toks.pkl'), 'rb'))
        self.dataset = pickle.load(open(os.path.join(aux_data_dir, suff + 'recipe1m_'+split+'.pkl'), 'rb'))

        self.label2word = self.get_ingrs_vocab()

        self.use_lmdb = use_lmdb
        if use_lmdb:
            self.image_file = lmdb.open(os.path.join(aux_data_dir, 'lmdb_' + split), max_readers=1, readonly=True,
                                        lock=False, readahead=False, meminit=False)

        self.ids = []
        self.split = split
        for i, entry in enumerate(self.dataset):
            if len(entry['images']) == 0:
                continue
            self.ids.append(i)

        self.root = os.path.join(data_dir, 'images', split)
        self.transform = transform
        self.max_num_labels = maxnumlabels
        self.maxseqlen = maxseqlen
        self.max_num_instrs = maxnuminstrs
        self.maxseqlen = maxseqlen*maxnuminstrs
        self.maxnumims = maxnumims
        if max_num_samples != -1:
            random.shuffle(self.ids)
            self.ids = self.ids[:max_num_samples]

    def get_instrs_vocab(self):
        return self.instrs_vocab

    def get_instrs_vocab_size(self):
        return len(self.instrs_vocab)

    def get_ingrs_vocab(self):
        return [min(w, key=len) if not isinstance(w, str) else w for w in
                self.ingrs_vocab.idx2word.values()]  # includes 'pad' ingredient

    def get_ingrs_vocab_size(self):
        return len(self.ingrs_vocab)

    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""

        sample = self.dataset[self.ids[index]]
        img_id = sample['id']
        captions = sample['tokenized']
        paths = sample['images'][0:self.maxnumims]

        idx = index

        labels = self.dataset[self.ids[idx]]['ingredients']
        title = sample['title']

        tokens = []
        tokens.extend(title)
        # add fake token to separate title from recipe
        tokens.append('<eoi>')
        for c in captions:
            tokens.extend(c)
            tokens.append('<eoi>')

        ilabels_gt = np.ones(self.max_num_labels) * self.ingrs_vocab('<pad>')
        pos = 0

        true_ingr_idxs = []
        for i in range(len(labels)):
            true_ingr_idxs.append(self.ingrs_vocab(labels[i]))

        for i in range(self.max_num_labels):
            if i >= len(labels):
                label = '<pad>'
            else:
                label = labels[i]
            label_idx = self.ingrs_vocab(label)
            if label_idx not in ilabels_gt:
                ilabels_gt[pos] = label_idx
                pos += 1

        ilabels_gt[pos] = self.ingrs_vocab('<end>')
        ingrs_gt = torch.from_numpy(ilabels_gt).long()

        if len(paths) == 0:
            path = None
            image_input = torch.zeros((3, 224, 224))
        else:
            if self.split == 'train':
                img_idx = np.random.randint(0, len(paths))
            else:
                img_idx = 0
            path = paths[img_idx]
            if self.use_lmdb:
                try:
                    with self.image_file.begin(write=False) as txn:
                        image = txn.get(path.encode())
                        image = np.fromstring(image, dtype=np.uint8)
                        image = np.reshape(image, (256, 256, 3))
                    image = Image.fromarray(image.astype('uint8'), 'RGB')
                except:
                    print ("Image id not found in lmdb. Loading jpeg file...")
                    image = Image.open(os.path.join(self.root, path[0], path[1],
                                                    path[2], path[3], path)).convert('RGB')
            else:
                image = Image.open(os.path.join(self.root, path[0], path[1], path[2], path[3], path)).convert('RGB')
            if self.transform is not None:
                image = self.transform(image)
            image_input = image

        # Convert caption (string) to word ids.
        caption = []

        caption = self.caption_to_idxs(tokens, caption)
        caption.append(self.instrs_vocab('<end>'))

        caption = caption[0:self.maxseqlen]
        target = torch.Tensor(caption)

        return image_input, target, ingrs_gt, img_id, path, self.instrs_vocab('<pad>')

    def __len__(self):
        return len(self.ids)

    def caption_to_idxs(self, tokens, caption):

        caption.append(self.instrs_vocab('<start>'))
        for token in tokens:
            caption.append(self.instrs_vocab(token))
        return caption


def collate_fn(data):

    # Sort a data list by caption length (descending order).
    # data.sort(key=lambda x: len(x[2]), reverse=True)
    image_input, captions, ingrs_gt, img_id, path, pad_value = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).

    image_input = torch.stack(image_input, 0)
    ingrs_gt = torch.stack(ingrs_gt, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.ones(len(captions), max(lengths)).long()*pad_value[0]

    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]

    return image_input, targets, ingrs_gt, img_id, path


def get_loader(data_dir, aux_data_dir, split, maxseqlen,
               maxnuminstrs, maxnumlabels, maxnumims, transform, batch_size,
               shuffle, num_workers, drop_last=False,
               max_num_samples=-1,
               use_lmdb=False,
               suff=''):

    dataset = Recipe1MDataset(data_dir=data_dir, aux_data_dir=aux_data_dir, split=split,
                              maxseqlen=maxseqlen, maxnumlabels=maxnumlabels, maxnuminstrs=maxnuminstrs,
                              maxnumims=maxnumims,
                              transform=transform,
                              max_num_samples=max_num_samples,
                              use_lmdb=use_lmdb,
                              suff=suff)

    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                                              drop_last=drop_last, collate_fn=collate_fn, pin_memory=True)
    return data_loader, dataset
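A sketch of how a training loader might be built from ```get_loader``` above. The 256/224 resize-and-crop sizes and the normalization constants match the defaults in ```args.py``` and the transform used in ```demo.ipynb```; the random-crop choice and batch size are illustrative, and the paths assume Recipe1M was extracted to ```path_to_dataset``` with the vocabulary pickles written to ```../data``` by ```build_vocab.py```.

```python
# Sketch: build a Recipe1M training loader with get_loader (illustrative settings).
import torchvision.transforms as transforms

from data_loader import get_loader

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

loader, dataset = get_loader(
    data_dir='path_to_dataset', aux_data_dir='../data', split='train',
    maxseqlen=15, maxnuminstrs=10, maxnumlabels=20, maxnumims=5,
    transform=train_transform, batch_size=4, shuffle=True,
    num_workers=0, use_lmdb=False)

images, targets, ingrs_gt, ids, paths = next(iter(loader))
print(images.shape, targets.shape, ingrs_gt.shape)   # (4, 3, 224, 224), (4, T), (4, 20)
```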
src/demo.ipynb
ADDED
@@ -0,0 +1,271 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"## Inverse Cooking: Recipe Generation from Food Images"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": null,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"import matplotlib.pyplot as plt\n",
|
17 |
+
"import torch\n",
|
18 |
+
"import torch.nn as nn\n",
|
19 |
+
"import numpy as np\n",
|
20 |
+
"import os\n",
|
21 |
+
"from args import get_parser\n",
|
22 |
+
"import pickle\n",
|
23 |
+
"from model import get_model\n",
|
24 |
+
"from torchvision import transforms\n",
|
25 |
+
"from utils.output_utils import prepare_output\n",
|
26 |
+
"from PIL import Image\n",
|
27 |
+
"import time"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "markdown",
|
32 |
+
"metadata": {},
|
33 |
+
"source": [
|
34 |
+
"Set ```data_dir``` to the path including vocabularies and model checkpoint"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": null,
|
40 |
+
"metadata": {},
|
41 |
+
"outputs": [],
|
42 |
+
"source": [
|
43 |
+
"data_dir = '../data'"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "code",
|
48 |
+
"execution_count": null,
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"# code will run in gpu if available and if the flag is set to True, else it will run on cpu\n",
|
53 |
+
"use_gpu = False\n",
|
54 |
+
"device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')\n",
|
55 |
+
"map_loc = None if torch.cuda.is_available() and use_gpu else 'cpu'"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": null,
|
61 |
+
"metadata": {},
|
62 |
+
"outputs": [],
|
63 |
+
"source": [
|
64 |
+
"# code below was used to save vocab files so that they can be loaded without Vocabulary class\n",
|
65 |
+
"#ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_ingrs.pkl'), 'rb'))\n",
|
66 |
+
"#ingrs_vocab = [min(w, key=len) if not isinstance(w, str) else w for w in ingrs_vocab.idx2word.values()]\n",
|
67 |
+
"#vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_toks.pkl'), 'rb')).idx2word\n",
|
68 |
+
"#pickle.dump(ingrs_vocab, open('../demo/ingr_vocab.pkl', 'wb'))\n",
|
69 |
+
"#pickle.dump(vocab, open('../demo/instr_vocab.pkl', 'wb'))\n",
|
70 |
+
"\n",
|
71 |
+
"ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'ingr_vocab.pkl'), 'rb'))\n",
|
72 |
+
"vocab = pickle.load(open(os.path.join(data_dir, 'instr_vocab.pkl'), 'rb'))\n",
|
73 |
+
"\n",
|
74 |
+
"ingr_vocab_size = len(ingrs_vocab)\n",
|
75 |
+
"instrs_vocab_size = len(vocab)\n",
|
76 |
+
"output_dim = instrs_vocab_size"
|
77 |
+
]
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"cell_type": "code",
|
81 |
+
"execution_count": null,
|
82 |
+
"metadata": {},
|
83 |
+
"outputs": [],
|
84 |
+
"source": [
|
85 |
+
"print (instrs_vocab_size, ingr_vocab_size)"
|
86 |
+
]
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "code",
|
90 |
+
"execution_count": null,
|
91 |
+
"metadata": {},
|
92 |
+
"outputs": [],
|
93 |
+
"source": [
|
94 |
+
"t = time.time()\n",
|
95 |
+
"import sys; sys.argv=['']; del sys\n",
|
96 |
+
"args = get_parser()\n",
|
97 |
+
"args.maxseqlen = 15\n",
|
98 |
+
"args.ingrs_only=False\n",
|
99 |
+
"model = get_model(args, ingr_vocab_size, instrs_vocab_size)\n",
|
100 |
+
"# Load the trained model parameters\n",
|
101 |
+
"model_path = os.path.join(data_dir, 'modelbest.ckpt')\n",
|
102 |
+
"model.load_state_dict(torch.load(model_path, map_location=map_loc))\n",
|
103 |
+
"model.to(device)\n",
|
104 |
+
"model.eval()\n",
|
105 |
+
"model.ingrs_only = False\n",
|
106 |
+
"model.recipe_only = False\n",
|
107 |
+
"print ('loaded model')\n",
|
108 |
+
"print (\"Elapsed time:\", time.time() -t)\n"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "code",
|
113 |
+
"execution_count": null,
|
114 |
+
"metadata": {},
|
115 |
+
"outputs": [],
|
116 |
+
"source": [
|
117 |
+
"transf_list_batch = []\n",
|
118 |
+
"transf_list_batch.append(transforms.ToTensor())\n",
|
119 |
+
"transf_list_batch.append(transforms.Normalize((0.485, 0.456, 0.406), \n",
|
120 |
+
" (0.229, 0.224, 0.225)))\n",
|
121 |
+
"to_input_transf = transforms.Compose(transf_list_batch)"
|
122 |
+
]
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"cell_type": "code",
|
126 |
+
"execution_count": null,
|
127 |
+
"metadata": {},
|
128 |
+
"outputs": [],
|
129 |
+
"source": [
|
130 |
+
"greedy = [True, False, False, False]\n",
|
131 |
+
"beam = [-1, -1, -1, -1]\n",
|
132 |
+
"temperature = 1.0\n",
|
133 |
+
"numgens = len(greedy)"
|
134 |
+
]
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"cell_type": "markdown",
|
138 |
+
"metadata": {},
|
139 |
+
"source": [
|
140 |
+
"Set ```use_urls = True``` to get recipes for images in ```demo_urls```. \n",
|
141 |
+
"\n",
|
142 |
+
"You can also set ```use_urls = False``` to get recipes for the images in ```data_dir/demo_imgs```."
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"cell_type": "code",
|
147 |
+
"execution_count": null,
|
148 |
+
"metadata": {
|
149 |
+
"scrolled": true
|
150 |
+
},
|
151 |
+
"outputs": [],
|
152 |
+
"source": [
|
153 |
+
"import requests\n",
|
154 |
+
"from io import BytesIO\n",
|
155 |
+
"import random\n",
|
156 |
+
"from collections import Counter\n",
|
157 |
+
"use_urls = False # set to True to load images from demo_urls instead of those in the demo_imgs folder\n",
|
158 |
+
"show_anyways = False #if True, it will show the recipe even if it's not valid\n",
|
159 |
+
"image_folder = os.path.join(data_dir, 'demo_imgs')\n",
|
160 |
+
"\n",
|
161 |
+
"if not use_urls:\n",
|
162 |
+
" demo_imgs = os.listdir(image_folder)\n",
|
163 |
+
" random.shuffle(demo_imgs)\n",
|
164 |
+
"\n",
|
165 |
+
"demo_urls = ['https://food.fnr.sndimg.com/content/dam/images/food/fullset/2013/12/9/0/FNK_Cheesecake_s4x3.jpg.rend.hgtvcom.826.620.suffix/1387411272847.jpeg',\n",
|
166 |
+
" 'https://www.196flavors.com/wp-content/uploads/2014/10/california-roll-3-FP.jpg']\n",
|
167 |
+
"\n",
|
168 |
+
"demo_files = demo_urls if use_urls else demo_imgs"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "code",
|
173 |
+
"execution_count": null,
|
174 |
+
"metadata": {},
|
175 |
+
"outputs": [],
|
176 |
+
"source": [
|
177 |
+
"for img_file in demo_files:\n",
|
178 |
+
" \n",
|
179 |
+
" if use_urls:\n",
|
180 |
+
" response = requests.get(img_file)\n",
|
181 |
+
" image = Image.open(BytesIO(response.content))\n",
|
182 |
+
" else:\n",
|
183 |
+
" image_path = os.path.join(image_folder, img_file)\n",
|
184 |
+
" image = Image.open(image_path).convert('RGB')\n",
|
185 |
+
" \n",
|
186 |
+
" transf_list = []\n",
|
187 |
+
" transf_list.append(transforms.Resize(256))\n",
|
188 |
+
" transf_list.append(transforms.CenterCrop(224))\n",
|
189 |
+
" transform = transforms.Compose(transf_list)\n",
|
190 |
+
" \n",
|
191 |
+
" image_transf = transform(image)\n",
|
192 |
+
" image_tensor = to_input_transf(image_transf).unsqueeze(0).to(device)\n",
|
193 |
+
" \n",
|
194 |
+
" plt.imshow(image_transf)\n",
|
195 |
+
" plt.axis('off')\n",
|
196 |
+
" plt.show()\n",
|
197 |
+
" plt.close()\n",
|
198 |
+
" \n",
|
199 |
+
" num_valid = 1\n",
|
200 |
+
" for i in range(numgens):\n",
|
201 |
+
" with torch.no_grad():\n",
|
202 |
+
" outputs = model.sample(image_tensor, greedy=greedy[i], \n",
|
203 |
+
" temperature=temperature, beam=beam[i], true_ingrs=None)\n",
|
204 |
+
" \n",
|
205 |
+
" ingr_ids = outputs['ingr_ids'].cpu().numpy()\n",
|
206 |
+
" recipe_ids = outputs['recipe_ids'].cpu().numpy()\n",
|
207 |
+
" \n",
|
208 |
+
" outs, valid = prepare_output(recipe_ids[0], ingr_ids[0], ingrs_vocab, vocab)\n",
|
209 |
+
" \n",
|
210 |
+
" if valid['is_valid'] or show_anyways:\n",
|
211 |
+
" \n",
|
212 |
+
" print ('RECIPE', num_valid)\n",
|
213 |
+
" num_valid+=1\n",
|
214 |
+
" #print (\"greedy:\", greedy[i], \"beam:\", beam[i])\n",
|
215 |
+
" \n",
|
216 |
+
" BOLD = '\\033[1m'\n",
|
217 |
+
" END = '\\033[0m'\n",
|
218 |
+
" print (BOLD + '\\nTitle:' + END,outs['title'])\n",
|
219 |
+
"\n",
|
220 |
+
" print (BOLD + '\\nIngredients:'+ END)\n",
|
221 |
+
" print (', '.join(outs['ingrs']))\n",
|
222 |
+
"\n",
|
223 |
+
" print (BOLD + '\\nInstructions:'+END)\n",
|
224 |
+
" print ('-'+'\\n-'.join(outs['recipe']))\n",
|
225 |
+
"\n",
|
226 |
+
" print ('='*20)\n",
|
227 |
+
"\n",
|
228 |
+
" else:\n",
|
229 |
+
" pass\n",
|
230 |
+
" print (\"Not a valid recipe!\")\n",
|
231 |
+
" print (\"Reason: \", valid['reason'])\n",
|
232 |
+
" "
|
233 |
+
]
|
234 |
+
},
|
235 |
+
{
|
236 |
+
"cell_type": "code",
|
237 |
+
"execution_count": null,
|
238 |
+
"metadata": {},
|
239 |
+
"outputs": [],
|
240 |
+
"source": []
|
241 |
+
},
|
242 |
+
{
|
243 |
+
"cell_type": "code",
|
244 |
+
"execution_count": null,
|
245 |
+
"metadata": {},
|
246 |
+
"outputs": [],
|
247 |
+
"source": []
|
248 |
+
}
|
249 |
+
],
|
250 |
+
"metadata": {
|
251 |
+
"kernelspec": {
|
252 |
+
"display_name": "Python 3",
|
253 |
+
"language": "python",
|
254 |
+
"name": "python3"
|
255 |
+
},
|
256 |
+
"language_info": {
|
257 |
+
"codemirror_mode": {
|
258 |
+
"name": "ipython",
|
259 |
+
"version": 3
|
260 |
+
},
|
261 |
+
"file_extension": ".py",
|
262 |
+
"mimetype": "text/x-python",
|
263 |
+
"name": "python",
|
264 |
+
"nbconvert_exporter": "python",
|
265 |
+
"pygments_lexer": "ipython3",
|
266 |
+
"version": "3.6.5"
|
267 |
+
}
|
268 |
+
},
|
269 |
+
"nbformat": 4,
|
270 |
+
"nbformat_minor": 2
|
271 |
+
}
|
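Note (not part of the committed notebook): the resize / center-crop / normalize steps from the cells above can be collected into a single helper. This is a minimal sketch; the mean/std values are the standard ImageNet statistics expected by the torchvision ResNet backbone, and `load_demo_image` is a hypothetical name.

```python
import torch
from PIL import Image
from torchvision import transforms

def load_demo_image(path: str) -> torch.Tensor:
    """Resize, center-crop and normalize one image for model.sample()."""
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet mean/std, as in the notebook cell above
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    image = Image.open(path).convert('RGB')
    return preprocess(image).unsqueeze(0)  # shape (1, 3, 224, 224)
```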
src/demo.py
ADDED
@@ -0,0 +1,133 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
from args import get_parser
|
6 |
+
import pickle
|
7 |
+
from model import get_model
|
8 |
+
from torchvision import transforms
|
9 |
+
from utils.output_ing import prepare_output
|
10 |
+
from PIL import Image
|
11 |
+
from tqdm import tqdm
|
12 |
+
import time
|
13 |
+
import glob
|
14 |
+
|
15 |
+
|
16 |
+
# Set model_dir to the path containing the vocabularies and the model checkpoint
|
17 |
+
model_dir = '../data'
|
18 |
+
image_folder = '../data/demo_imgs'
|
19 |
+
output_file = "../data/predicted_ingr.pkl"
|
20 |
+
|
21 |
+
# code will run in gpu if available and if the flag is set to True, else it will run on cpu
|
22 |
+
use_gpu = False
|
23 |
+
device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
|
24 |
+
map_loc = None if torch.cuda.is_available() and use_gpu else 'cpu'
|
25 |
+
|
26 |
+
# code below was used to save vocab files so that they can be loaded without Vocabulary class
|
27 |
+
#ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_ingrs.pkl'), 'rb'))
|
28 |
+
#ingrs_vocab = [min(w, key=len) if not isinstance(w, str) else w for w in ingrs_vocab.idx2word.values()]
|
29 |
+
#vocab = pickle.load(open(os.path.join(data_dir, 'final_recipe1m_vocab_toks.pkl'), 'rb')).idx2word
|
30 |
+
#pickle.dump(ingrs_vocab, open('../demo/ingr_vocab.pkl', 'wb'))
|
31 |
+
#pickle.dump(vocab, open('../demo/instr_vocab.pkl', 'wb'))
|
32 |
+
|
33 |
+
ingrs_vocab = pickle.load(open(os.path.join(model_dir, 'ingr_vocab.pkl'), 'rb'))
|
34 |
+
vocab = pickle.load(open(os.path.join(model_dir, 'instr_vocab.pkl'), 'rb'))
|
35 |
+
|
36 |
+
ingr_vocab_size = len(ingrs_vocab)
|
37 |
+
instrs_vocab_size = len(vocab)
|
38 |
+
output_dim = instrs_vocab_size
|
39 |
+
|
40 |
+
print (instrs_vocab_size, ingr_vocab_size)
|
41 |
+
|
42 |
+
t = time.time()
|
43 |
+
|
44 |
+
args = get_parser()
|
45 |
+
args.maxseqlen = 15
|
46 |
+
args.ingrs_only=True
|
47 |
+
model = get_model(args, ingr_vocab_size, instrs_vocab_size)
|
48 |
+
# Load the trained model parameters
|
49 |
+
model_path = os.path.join(model_dir, 'modelbest.ckpt')
|
50 |
+
model.load_state_dict(torch.load(model_path, map_location=map_loc))
|
51 |
+
model.to(device)
|
52 |
+
model.eval()
|
53 |
+
model.ingrs_only = True
|
54 |
+
model.recipe_only = False
|
55 |
+
print ('loaded model')
|
56 |
+
print ("Elapsed time:", time.time() -t)
|
57 |
+
|
58 |
+
transf_list_batch = []
|
59 |
+
transf_list_batch.append(transforms.ToTensor())
|
60 |
+
transf_list_batch.append(transforms.Normalize((0.485, 0.456, 0.406),
|
61 |
+
(0.229, 0.224, 0.225)))
|
62 |
+
to_input_transf = transforms.Compose(transf_list_batch)
|
63 |
+
|
64 |
+
|
65 |
+
greedy = True
|
66 |
+
beam = -1
|
67 |
+
temperature = 1.0
|
68 |
+
|
69 |
+
# import requests
|
70 |
+
# from io import BytesIO
|
71 |
+
# import random
|
72 |
+
# from collections import Counter
|
73 |
+
# use_urls = False # set to true to load images from demo_urls instead of those in test_imgs folder
|
74 |
+
# show_anyways = False #if True, it will show the recipe even if it's not valid
|
75 |
+
# image_folder = os.path.join(data_dir, 'demo_imgs')
|
76 |
+
|
77 |
+
# if not use_urls:
|
78 |
+
# demo_imgs = os.listdir(image_folder)
|
79 |
+
# random.shuffle(demo_imgs)
|
80 |
+
|
81 |
+
# demo_urls = ['https://food.fnr.sndimg.com/content/dam/images/food/fullset/2013/12/9/0/FNK_Cheesecake_s4x3.jpg.rend.hgtvcom.826.620.suffix/1387411272847.jpeg',
|
82 |
+
# 'https://www.196flavors.com/wp-content/uploads/2014/10/california-roll-3-FP.jpg']
|
83 |
+
|
84 |
+
files_path = glob.glob(f"{image_folder}/*/*/*.jpg")
|
85 |
+
print(f"total data: {len(files_path)}")
|
86 |
+
|
87 |
+
res = []
|
88 |
+
for idx, img_file in tqdm(enumerate(files_path)):
|
89 |
+
# if use_urls:
|
90 |
+
# response = requests.get(img_file)
|
91 |
+
# image = Image.open(BytesIO(response.content))
|
92 |
+
# else:
|
93 |
+
image = Image.open(img_file).convert('RGB')
|
94 |
+
|
95 |
+
transf_list = []
|
96 |
+
transf_list.append(transforms.Resize(256))
|
97 |
+
transf_list.append(transforms.CenterCrop(224))
|
98 |
+
transform = transforms.Compose(transf_list)
|
99 |
+
|
100 |
+
image_transf = transform(image)
|
101 |
+
image_tensor = to_input_transf(image_transf).unsqueeze(0).to(device)
|
102 |
+
|
103 |
+
# plt.imshow(image_transf)
|
104 |
+
# plt.axis('off')
|
105 |
+
# plt.show()
|
106 |
+
# plt.close()
|
107 |
+
|
108 |
+
with torch.no_grad():
|
109 |
+
outputs = model.sample(image_tensor, greedy=greedy,
|
110 |
+
temperature=temperature, beam=beam, true_ingrs=None)
|
111 |
+
|
112 |
+
ingr_ids = outputs['ingr_ids'].cpu().numpy()
|
113 |
+
print(ingr_ids)
|
114 |
+
|
115 |
+
outs = prepare_output(ingr_ids[0], ingrs_vocab)
|
116 |
+
# print(ingrs_vocab.idx2word)
|
117 |
+
|
118 |
+
print(outs)
|
119 |
+
|
120 |
+
# print ('Pic ' + str(idx+1) + ':')
|
121 |
+
|
122 |
+
# print ('\nIngredients:')
|
123 |
+
# print (', '.join(outs['ingrs']))
|
124 |
+
|
125 |
+
# print ('='*20)
|
126 |
+
|
127 |
+
res.append({
|
128 |
+
"id": img_file,
|
129 |
+
"ingredients": outs['ingrs']
|
130 |
+
})
|
131 |
+
|
132 |
+
with open(output_file, "wb") as fp: #Pickling
|
133 |
+
pickle.dump(res, fp)
|
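Note (not part of the committed script): demo.py serializes its predictions as a plain list of dicts, so reading the file back is straightforward. A minimal sketch, assuming the script has already written ../data/predicted_ingr.pkl:

```python
import pickle

# each entry follows the structure appended to `res` above:
# {"id": <image path>, "ingredients": [<ingredient strings>]}
with open("../data/predicted_ingr.pkl", "rb") as fp:
    predictions = pickle.load(fp)

for entry in predictions[:3]:
    print(entry["id"], "->", ", ".join(entry["ingredients"]))
```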
src/model.py
ADDED
@@ -0,0 +1,236 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
from src.modules.encoder import EncoderCNN, EncoderLabels
|
8 |
+
from src.modules.transformer_decoder import DecoderTransformer
|
9 |
+
from src.modules.multihead_attention import MultiheadAttention
|
10 |
+
from src.utils.metrics import softIoU, MaskedCrossEntropyCriterion
|
11 |
+
import pickle
|
12 |
+
import os
|
13 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
14 |
+
|
15 |
+
|
16 |
+
def label2onehot(labels, pad_value):
|
17 |
+
|
18 |
+
# input labels to one hot vector
|
19 |
+
inp_ = torch.unsqueeze(labels, 2)
|
20 |
+
one_hot = torch.FloatTensor(labels.size(0), labels.size(1), pad_value + 1).zero_().to(device)
|
21 |
+
one_hot.scatter_(2, inp_, 1)
|
22 |
+
one_hot, _ = one_hot.max(dim=1)
|
23 |
+
# remove pad position
|
24 |
+
one_hot = one_hot[:, :-1]
|
25 |
+
# eos position is always 0
|
26 |
+
one_hot[:, 0] = 0
|
27 |
+
|
28 |
+
return one_hot
|
29 |
+
|
30 |
+
|
31 |
+
def mask_from_eos(ids, eos_value, mult_before=True):
|
32 |
+
mask = torch.ones(ids.size()).to(device).byte()
|
33 |
+
mask_aux = torch.ones(ids.size(0)).to(device).byte()
|
34 |
+
|
35 |
+
# find eos in ingredient prediction
|
36 |
+
for idx in range(ids.size(1)):
|
37 |
+
# force mask to have 1s in the first position to avoid division by 0 when predictions start with eos
|
38 |
+
if idx == 0:
|
39 |
+
continue
|
40 |
+
if mult_before:
|
41 |
+
mask[:, idx] = mask[:, idx] * mask_aux
|
42 |
+
mask_aux = mask_aux * (ids[:, idx] != eos_value)
|
43 |
+
else:
|
44 |
+
mask_aux = mask_aux * (ids[:, idx] != eos_value)
|
45 |
+
mask[:, idx] = mask[:, idx] * mask_aux
|
46 |
+
return mask
|
47 |
+
|
48 |
+
|
49 |
+
def get_model(args, ingr_vocab_size, instrs_vocab_size):
|
50 |
+
|
51 |
+
# build ingredients embedding
|
52 |
+
encoder_ingrs = EncoderLabels(args.embed_size, ingr_vocab_size,
|
53 |
+
args.dropout_encoder, scale_grad=False).to(device)
|
54 |
+
# build image model
|
55 |
+
encoder_image = EncoderCNN(args.embed_size, args.dropout_encoder, args.image_model)
|
56 |
+
|
57 |
+
decoder = DecoderTransformer(args.embed_size, instrs_vocab_size,
|
58 |
+
dropout=args.dropout_decoder_r, seq_length=args.maxseqlen,
|
59 |
+
num_instrs=args.maxnuminstrs,
|
60 |
+
attention_nheads=args.n_att, num_layers=args.transf_layers,
|
61 |
+
normalize_before=True,
|
62 |
+
normalize_inputs=False,
|
63 |
+
last_ln=False,
|
64 |
+
scale_embed_grad=False)
|
65 |
+
|
66 |
+
ingr_decoder = DecoderTransformer(args.embed_size, ingr_vocab_size, dropout=args.dropout_decoder_i,
|
67 |
+
seq_length=args.maxnumlabels,
|
68 |
+
num_instrs=1, attention_nheads=args.n_att_ingrs,
|
69 |
+
pos_embeddings=False,
|
70 |
+
num_layers=args.transf_layers_ingrs,
|
71 |
+
learned=False,
|
72 |
+
normalize_before=True,
|
73 |
+
normalize_inputs=True,
|
74 |
+
last_ln=True,
|
75 |
+
scale_embed_grad=False)
|
76 |
+
# recipe loss
|
77 |
+
criterion = MaskedCrossEntropyCriterion(ignore_index=[instrs_vocab_size-1], reduce=False)
|
78 |
+
|
79 |
+
# ingredients loss
|
80 |
+
label_loss = nn.BCELoss(reduce=False)
|
81 |
+
eos_loss = nn.BCELoss(reduce=False)
|
82 |
+
|
83 |
+
model = InverseCookingModel(encoder_ingrs, decoder, ingr_decoder, encoder_image,
|
84 |
+
crit=criterion, crit_ingr=label_loss, crit_eos=eos_loss,
|
85 |
+
pad_value=ingr_vocab_size-1,
|
86 |
+
ingrs_only=args.ingrs_only, recipe_only=args.recipe_only,
|
87 |
+
label_smoothing=args.label_smoothing_ingr)
|
88 |
+
|
89 |
+
return model
|
90 |
+
|
91 |
+
|
92 |
+
class InverseCookingModel(nn.Module):
|
93 |
+
def __init__(self, ingredient_encoder, recipe_decoder, ingr_decoder, image_encoder,
|
94 |
+
crit=None, crit_ingr=None, crit_eos=None,
|
95 |
+
pad_value=0, ingrs_only=True,
|
96 |
+
recipe_only=False, label_smoothing=0.0):
|
97 |
+
|
98 |
+
super(InverseCookingModel, self).__init__()
|
99 |
+
|
100 |
+
self.ingredient_encoder = ingredient_encoder
|
101 |
+
self.recipe_decoder = recipe_decoder
|
102 |
+
self.image_encoder = image_encoder
|
103 |
+
self.ingredient_decoder = ingr_decoder
|
104 |
+
self.crit = crit
|
105 |
+
self.crit_ingr = crit_ingr
|
106 |
+
self.pad_value = pad_value
|
107 |
+
self.ingrs_only = ingrs_only
|
108 |
+
self.recipe_only = recipe_only
|
109 |
+
self.crit_eos = crit_eos
|
110 |
+
self.label_smoothing = label_smoothing
|
111 |
+
|
112 |
+
def forward(self, img_inputs, captions, target_ingrs,
|
113 |
+
sample=False, keep_cnn_gradients=False):
|
114 |
+
|
115 |
+
if sample:
|
116 |
+
return self.sample(img_inputs, greedy=True)
|
117 |
+
|
118 |
+
targets = captions[:, 1:]
|
119 |
+
targets = targets.contiguous().view(-1)
|
120 |
+
|
121 |
+
img_features = self.image_encoder(img_inputs, keep_cnn_gradients)
|
122 |
+
|
123 |
+
losses = {}
|
124 |
+
target_one_hot = label2onehot(target_ingrs, self.pad_value)
|
125 |
+
target_one_hot_smooth = label2onehot(target_ingrs, self.pad_value)
|
126 |
+
|
127 |
+
# ingredient prediction
|
128 |
+
if not self.recipe_only:
|
129 |
+
target_one_hot_smooth[target_one_hot_smooth == 1] = (1-self.label_smoothing)
|
130 |
+
target_one_hot_smooth[target_one_hot_smooth == 0] = self.label_smoothing / target_one_hot_smooth.size(-1)
|
131 |
+
|
132 |
+
# decode ingredients with transformer
|
133 |
+
# autoregressive mode for ingredient decoder
|
134 |
+
ingr_ids, ingr_logits = self.ingredient_decoder.sample(None, None, greedy=True,
|
135 |
+
temperature=1.0, img_features=img_features,
|
136 |
+
first_token_value=0, replacement=False)
|
137 |
+
|
138 |
+
ingr_logits = torch.nn.functional.softmax(ingr_logits, dim=-1)
|
139 |
+
|
140 |
+
# find idxs for eos ingredient
|
141 |
+
# eos probability is the one assigned to the first position of the softmax
|
142 |
+
eos = ingr_logits[:, :, 0]
|
143 |
+
target_eos = ((target_ingrs == 0) ^ (target_ingrs == self.pad_value))
|
144 |
+
|
145 |
+
eos_pos = (target_ingrs == 0)
|
146 |
+
eos_head = ((target_ingrs != self.pad_value) & (target_ingrs != 0))
|
147 |
+
|
148 |
+
# select transformer steps to pool from
|
149 |
+
mask_perminv = mask_from_eos(target_ingrs, eos_value=0, mult_before=False)
|
150 |
+
ingr_probs = ingr_logits * mask_perminv.float().unsqueeze(-1)
|
151 |
+
|
152 |
+
ingr_probs, _ = torch.max(ingr_probs, dim=1)
|
153 |
+
|
154 |
+
# ignore predicted ingredients after eos in ground truth
|
155 |
+
ingr_ids[mask_perminv == 0] = self.pad_value
|
156 |
+
|
157 |
+
ingr_loss = self.crit_ingr(ingr_probs, target_one_hot_smooth)
|
158 |
+
ingr_loss = torch.mean(ingr_loss, dim=-1)
|
159 |
+
|
160 |
+
losses['ingr_loss'] = ingr_loss
|
161 |
+
|
162 |
+
# cardinality penalty
|
163 |
+
losses['card_penalty'] = torch.abs((ingr_probs*target_one_hot).sum(1) - target_one_hot.sum(1)) + \
|
164 |
+
torch.abs((ingr_probs*(1-target_one_hot)).sum(1))
|
165 |
+
|
166 |
+
eos_loss = self.crit_eos(eos, target_eos.float())
|
167 |
+
|
168 |
+
mult = 1/2
|
169 |
+
# eos loss is only computed for timesteps <= t_eos and equally penalizes 0s and 1s
|
170 |
+
losses['eos_loss'] = mult*(eos_loss * eos_pos.float()).sum(1) / (eos_pos.float().sum(1) + 1e-6) + \
|
171 |
+
mult*(eos_loss * eos_head.float()).sum(1) / (eos_head.float().sum(1) + 1e-6)
|
172 |
+
# iou
|
173 |
+
pred_one_hot = label2onehot(ingr_ids, self.pad_value)
|
174 |
+
# iou sample during training is computed using the true eos position
|
175 |
+
losses['iou'] = softIoU(pred_one_hot, target_one_hot)
|
176 |
+
|
177 |
+
if self.ingrs_only:
|
178 |
+
return losses
|
179 |
+
|
180 |
+
# encode ingredients
|
181 |
+
target_ingr_feats = self.ingredient_encoder(target_ingrs)
|
182 |
+
target_ingr_mask = mask_from_eos(target_ingrs, eos_value=0, mult_before=False)
|
183 |
+
|
184 |
+
target_ingr_mask = target_ingr_mask.float().unsqueeze(1)
|
185 |
+
|
186 |
+
outputs, ids = self.recipe_decoder(target_ingr_feats, target_ingr_mask, captions, img_features)
|
187 |
+
|
188 |
+
outputs = outputs[:, :-1, :].contiguous()
|
189 |
+
outputs = outputs.view(outputs.size(0) * outputs.size(1), -1)
|
190 |
+
|
191 |
+
loss = self.crit(outputs, targets)
|
192 |
+
|
193 |
+
losses['recipe_loss'] = loss
|
194 |
+
|
195 |
+
return losses
|
196 |
+
|
197 |
+
def sample(self, img_inputs, greedy=True, temperature=1.0, beam=-1, true_ingrs=None):
|
198 |
+
|
199 |
+
outputs = dict()
|
200 |
+
|
201 |
+
img_features = self.image_encoder(img_inputs)
|
202 |
+
|
203 |
+
if not self.recipe_only:
|
204 |
+
ingr_ids, ingr_probs = self.ingredient_decoder.sample(None, None, greedy=True, temperature=temperature,
|
205 |
+
beam=-1,
|
206 |
+
img_features=img_features, first_token_value=0,
|
207 |
+
replacement=False)
|
208 |
+
|
209 |
+
# mask ingredients after finding eos
|
210 |
+
sample_mask = mask_from_eos(ingr_ids, eos_value=0, mult_before=False)
|
211 |
+
ingr_ids[sample_mask == 0] = self.pad_value
|
212 |
+
|
213 |
+
outputs['ingr_ids'] = ingr_ids
|
214 |
+
outputs['ingr_probs'] = ingr_probs.data
|
215 |
+
|
216 |
+
mask = sample_mask
|
217 |
+
input_mask = mask.float().unsqueeze(1)
|
218 |
+
input_feats = self.ingredient_encoder(ingr_ids)
|
219 |
+
|
220 |
+
if self.ingrs_only:
|
221 |
+
return outputs
|
222 |
+
|
223 |
+
# option during sampling to use the real ingredients and not the predicted ones to infer the recipe
|
224 |
+
if true_ingrs is not None:
|
225 |
+
input_mask = mask_from_eos(true_ingrs, eos_value=0, mult_before=False)
|
226 |
+
true_ingrs[input_mask == 0] = self.pad_value
|
227 |
+
input_feats = self.ingredient_encoder(true_ingrs)
|
228 |
+
input_mask = input_mask.unsqueeze(1)
|
229 |
+
|
230 |
+
ids, probs = self.recipe_decoder.sample(input_feats, input_mask, greedy, temperature, beam, img_features, 0,
|
231 |
+
last_token_value=1)
|
232 |
+
|
233 |
+
outputs['recipe_probs'] = probs.data
|
234 |
+
outputs['recipe_ids'] = ids
|
235 |
+
|
236 |
+
return outputs
|
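Note (not part of the committed file): a small sketch of how label2onehot and mask_from_eos behave on a toy batch, assuming the same conventions as above (token 0 is the eos, the last vocabulary index is the pad value) and that the repository root is on PYTHONPATH.

```python
import torch
from src.model import label2onehot, mask_from_eos

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# batch of 2 samples, 4 ingredient slots; 0 = <eos>, 4 = <pad> (pad_value)
labels = torch.tensor([[2, 3, 0, 4],
                       [1, 0, 4, 4]]).to(device)

one_hot = label2onehot(labels, pad_value=4)
# shape (2, 4): the pad column is dropped and the eos column is forced to 0
print(one_hot)

mask = mask_from_eos(labels, eos_value=0, mult_before=False)
# 1 for slots before the first <eos> (position 0 always kept), 0 afterwards
print(mask)
```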
src/model1_inf.py
ADDED
@@ -0,0 +1,43 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
from src.args import get_parser
|
6 |
+
import pickle
|
7 |
+
from src.model import get_model
|
8 |
+
from torchvision import transforms
|
9 |
+
from src.utils.output_ing import prepare_output
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
13 |
+
map_loc = None if torch.cuda.is_available() else 'cpu'
|
14 |
+
|
15 |
+
def im2ingr(image, ingrs_vocab, model):
|
16 |
+
transf_list_batch = []
|
17 |
+
transf_list_batch.append(transforms.ToTensor())
|
18 |
+
transf_list_batch.append(transforms.Normalize((0.485, 0.456, 0.406),
|
19 |
+
(0.229, 0.224, 0.225)))
|
20 |
+
to_input_transf = transforms.Compose(transf_list_batch)
|
21 |
+
|
22 |
+
greedy = True
|
23 |
+
beam = -1
|
24 |
+
temperature = 1.0
|
25 |
+
|
26 |
+
transf_list = []
|
27 |
+
transf_list.append(transforms.Resize(256))
|
28 |
+
transf_list.append(transforms.CenterCrop(224))
|
29 |
+
transform = transforms.Compose(transf_list)
|
30 |
+
|
31 |
+
image_transf = transform(image)
|
32 |
+
image_tensor = to_input_transf(image_transf).unsqueeze(0).to(device)
|
33 |
+
|
34 |
+
with torch.no_grad():
|
35 |
+
outputs = model.sample(image_tensor, greedy=greedy,
|
36 |
+
temperature=temperature, beam=beam, true_ingrs=None)
|
37 |
+
|
38 |
+
ingr_ids = outputs['ingr_ids'].cpu().numpy()
|
39 |
+
outs = prepare_output(ingr_ids[0], ingrs_vocab)
|
40 |
+
|
41 |
+
return outs['ingrs']
|
42 |
+
|
43 |
+
|
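Note (not part of the committed file): a hypothetical usage sketch for im2ingr, assuming the vocabularies and checkpoint live under data/ (as in app.py and demo.py) and that the repository root is on PYTHONPATH.

```python
import os
import pickle

import torch
from PIL import Image

from src.args import get_parser
from src.model import get_model
from src.model1_inf import im2ingr

data_dir = 'data'
ingrs_vocab = pickle.load(open(os.path.join(data_dir, 'ingr_vocab.pkl'), 'rb'))
instr_vocab = pickle.load(open(os.path.join(data_dir, 'instr_vocab.pkl'), 'rb'))

args = get_parser()
args.maxseqlen = 15
args.ingrs_only = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(args, len(ingrs_vocab), len(instr_vocab))
model.load_state_dict(torch.load(os.path.join(data_dir, 'modelbest.ckpt'),
                                 map_location='cpu'))
model.to(device)
model.eval()

image = Image.open(os.path.join(data_dir, 'demo_imgs', '1.jpg')).convert('RGB')
print(im2ingr(image, ingrs_vocab, model))  # list of predicted ingredient strings
```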
src/modules/encoder.py
ADDED
@@ -0,0 +1,57 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
from torchvision.models import resnet18, resnet50, resnet101, resnet152, vgg16, vgg19, inception_v3
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import random
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
class EncoderCNN(nn.Module):
|
11 |
+
def __init__(self, embed_size, dropout=0.5, image_model='resnet101', pretrained=True):
|
12 |
+
"""Load a pretrained CNN backbone (resnet101 by default) and replace the top fc layer."""
|
13 |
+
super(EncoderCNN, self).__init__()
|
14 |
+
resnet = globals()[image_model](pretrained=pretrained)
|
15 |
+
modules = list(resnet.children())[:-2] # delete the last fc layer.
|
16 |
+
self.resnet = nn.Sequential(*modules)
|
17 |
+
|
18 |
+
self.linear = nn.Sequential(nn.Conv2d(resnet.fc.in_features, embed_size, kernel_size=1, padding=0),
|
19 |
+
nn.Dropout2d(dropout))
|
20 |
+
|
21 |
+
def forward(self, images, keep_cnn_gradients=False):
|
22 |
+
"""Extract feature vectors from input images."""
|
23 |
+
|
24 |
+
if keep_cnn_gradients:
|
25 |
+
raw_conv_feats = self.resnet(images)
|
26 |
+
else:
|
27 |
+
with torch.no_grad():
|
28 |
+
raw_conv_feats = self.resnet(images)
|
29 |
+
features = self.linear(raw_conv_feats)
|
30 |
+
features = features.view(features.size(0), features.size(1), -1)
|
31 |
+
|
32 |
+
return features
|
33 |
+
|
34 |
+
|
35 |
+
class EncoderLabels(nn.Module):
|
36 |
+
def __init__(self, embed_size, num_classes, dropout=0.5, embed_weights=None, scale_grad=False):
|
37 |
+
|
38 |
+
super(EncoderLabels, self).__init__()
|
39 |
+
embeddinglayer = nn.Embedding(num_classes, embed_size, padding_idx=num_classes-1, scale_grad_by_freq=scale_grad)
|
40 |
+
if embed_weights is not None:
|
41 |
+
embeddinglayer.weight.data.copy_(embed_weights)
|
42 |
+
self.pad_value = num_classes - 1
|
43 |
+
self.linear = embeddinglayer
|
44 |
+
self.dropout = dropout
|
45 |
+
self.embed_size = embed_size
|
46 |
+
|
47 |
+
def forward(self, x, onehot_flag=False):
|
48 |
+
|
49 |
+
if onehot_flag:
|
50 |
+
embeddings = torch.matmul(x, self.linear.weight)
|
51 |
+
else:
|
52 |
+
embeddings = self.linear(x)
|
53 |
+
|
54 |
+
embeddings = nn.functional.dropout(embeddings, p=self.dropout, training=self.training)
|
55 |
+
embeddings = embeddings.permute(0, 2, 1).contiguous()
|
56 |
+
|
57 |
+
return embeddings
|
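Note (not part of the committed file): a shape-check sketch for EncoderCNN. With a 224x224 input the backbone leaves a 7x7 spatial grid, so the encoder returns embed_size x 49 features per image, which is what the transformer decoder attends over. resnet18 with pretrained=False is used here only to keep the example light.

```python
import torch
from src.modules.encoder import EncoderCNN

encoder = EncoderCNN(embed_size=512, dropout=0.3, image_model='resnet18', pretrained=False)
encoder.eval()

images = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
with torch.no_grad():
    feats = encoder(images)

print(feats.shape)  # torch.Size([2, 512, 49])
```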
src/modules/multihead_attention.py
ADDED
@@ -0,0 +1,203 @@
1 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the LICENSE file in
|
5 |
+
# https://github.com/pytorch/fairseq. An additional grant of patent rights
|
6 |
+
# can be found in the PATENTS file in the same directory.
|
7 |
+
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
from torch.nn import Parameter
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
from src.modules.utils import fill_with_neg_inf, get_incremental_state, set_incremental_state
|
15 |
+
|
16 |
+
|
17 |
+
class MultiheadAttention(nn.Module):
|
18 |
+
"""Multi-headed attention.
|
19 |
+
See "Attention Is All You Need" for more details.
|
20 |
+
"""
|
21 |
+
def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
|
22 |
+
super().__init__()
|
23 |
+
self.embed_dim = embed_dim
|
24 |
+
self.num_heads = num_heads
|
25 |
+
self.dropout = dropout
|
26 |
+
self.head_dim = embed_dim // num_heads
|
27 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
28 |
+
self.scaling = self.head_dim**-0.5
|
29 |
+
self._mask = None
|
30 |
+
|
31 |
+
self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
|
32 |
+
if bias:
|
33 |
+
self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
|
34 |
+
else:
|
35 |
+
self.register_parameter('in_proj_bias', None)
|
36 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
37 |
+
|
38 |
+
self.reset_parameters()
|
39 |
+
|
40 |
+
def reset_parameters(self):
|
41 |
+
nn.init.xavier_uniform_(self.in_proj_weight)
|
42 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
43 |
+
if self.in_proj_bias is not None:
|
44 |
+
nn.init.constant_(self.in_proj_bias, 0.)
|
45 |
+
nn.init.constant_(self.out_proj.bias, 0.)
|
46 |
+
|
47 |
+
def forward(self, query, key, value, mask_future_timesteps=False,
|
48 |
+
key_padding_mask=None, incremental_state=None,
|
49 |
+
need_weights=True, static_kv=False):
|
50 |
+
"""Input shape: Time x Batch x Channel
|
51 |
+
Self-attention can be implemented by passing in the same arguments for
|
52 |
+
query, key and value. Future timesteps can be masked with the
|
53 |
+
`mask_future_timesteps` argument. Padding elements can be excluded from
|
54 |
+
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
|
55 |
+
batch x src_len, where padding elements are indicated by 1s.
|
56 |
+
"""
|
57 |
+
|
58 |
+
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
|
59 |
+
kv_same = key.data_ptr() == value.data_ptr()
|
60 |
+
|
61 |
+
tgt_len, bsz, embed_dim = query.size()
|
62 |
+
assert embed_dim == self.embed_dim
|
63 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
64 |
+
assert key.size() == value.size()
|
65 |
+
|
66 |
+
if incremental_state is not None:
|
67 |
+
saved_state = self._get_input_buffer(incremental_state)
|
68 |
+
if 'prev_key' in saved_state:
|
69 |
+
# previous time steps are cached - no need to recompute
|
70 |
+
# key and value if they are static
|
71 |
+
if static_kv:
|
72 |
+
assert kv_same and not qkv_same
|
73 |
+
key = value = None
|
74 |
+
else:
|
75 |
+
saved_state = None
|
76 |
+
|
77 |
+
if qkv_same:
|
78 |
+
# self-attention
|
79 |
+
q, k, v = self.in_proj_qkv(query)
|
80 |
+
elif kv_same:
|
81 |
+
# encoder-decoder attention
|
82 |
+
q = self.in_proj_q(query)
|
83 |
+
if key is None:
|
84 |
+
assert value is None
|
85 |
+
# create an empty tensor here so that concatenating it with
|
86 |
+
# the cached previous value below just returns that value
|
87 |
+
k = v = q.new(0)
|
88 |
+
else:
|
89 |
+
k, v = self.in_proj_kv(key)
|
90 |
+
else:
|
91 |
+
q = self.in_proj_q(query)
|
92 |
+
k = self.in_proj_k(key)
|
93 |
+
v = self.in_proj_v(value)
|
94 |
+
q *= self.scaling
|
95 |
+
|
96 |
+
if saved_state is not None:
|
97 |
+
if 'prev_key' in saved_state:
|
98 |
+
k = torch.cat((saved_state['prev_key'], k), dim=0)
|
99 |
+
if 'prev_value' in saved_state:
|
100 |
+
v = torch.cat((saved_state['prev_value'], v), dim=0)
|
101 |
+
saved_state['prev_key'] = k
|
102 |
+
saved_state['prev_value'] = v
|
103 |
+
self._set_input_buffer(incremental_state, saved_state)
|
104 |
+
|
105 |
+
src_len = k.size(0)
|
106 |
+
|
107 |
+
if key_padding_mask is not None:
|
108 |
+
assert key_padding_mask.size(0) == bsz
|
109 |
+
assert key_padding_mask.size(1) == src_len
|
110 |
+
|
111 |
+
q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
|
112 |
+
k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
|
113 |
+
v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
|
114 |
+
|
115 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
116 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
117 |
+
|
118 |
+
# only apply masking at training time (when incremental state is None)
|
119 |
+
if mask_future_timesteps and incremental_state is None:
|
120 |
+
assert query.size() == key.size(), \
|
121 |
+
'mask_future_timesteps only applies to self-attention'
|
122 |
+
attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
|
123 |
+
if key_padding_mask is not None:
|
124 |
+
# don't attend to padding symbols
|
125 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
126 |
+
attn_weights = attn_weights.float().masked_fill(
|
127 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
128 |
+
float('-inf'),
|
129 |
+
).type_as(attn_weights) # FP16 support: cast to float and back
|
130 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
131 |
+
|
132 |
+
attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
|
133 |
+
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
|
134 |
+
|
135 |
+
attn = torch.bmm(attn_weights, v)
|
136 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
137 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
138 |
+
attn = self.out_proj(attn)
|
139 |
+
|
140 |
+
# average attention weights over heads
|
141 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
142 |
+
attn_weights = attn_weights.sum(dim=1) / self.num_heads
|
143 |
+
|
144 |
+
return attn, attn_weights
|
145 |
+
|
146 |
+
def in_proj_qkv(self, query):
|
147 |
+
return self._in_proj(query).chunk(3, dim=-1)
|
148 |
+
|
149 |
+
def in_proj_kv(self, key):
|
150 |
+
return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
|
151 |
+
|
152 |
+
def in_proj_q(self, query):
|
153 |
+
return self._in_proj(query, end=self.embed_dim)
|
154 |
+
|
155 |
+
def in_proj_k(self, key):
|
156 |
+
return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim)
|
157 |
+
|
158 |
+
def in_proj_v(self, value):
|
159 |
+
return self._in_proj(value, start=2*self.embed_dim)
|
160 |
+
|
161 |
+
def _in_proj(self, input, start=None, end=None):
|
162 |
+
weight = self.in_proj_weight
|
163 |
+
bias = self.in_proj_bias
|
164 |
+
if end is not None:
|
165 |
+
weight = weight[:end, :]
|
166 |
+
if bias is not None:
|
167 |
+
bias = bias[:end]
|
168 |
+
if start is not None:
|
169 |
+
weight = weight[start:, :]
|
170 |
+
if bias is not None:
|
171 |
+
bias = bias[start:]
|
172 |
+
return F.linear(input, weight, bias)
|
173 |
+
|
174 |
+
def buffered_mask(self, tensor):
|
175 |
+
dim = tensor.size(-1)
|
176 |
+
if self._mask is None:
|
177 |
+
self._mask = torch.triu(fill_with_neg_inf(tensor.new(dim, dim)), 1)
|
178 |
+
if self._mask.size(0) < dim:
|
179 |
+
self._mask = torch.triu(fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
|
180 |
+
return self._mask[:dim, :dim]
|
181 |
+
|
182 |
+
def reorder_incremental_state(self, incremental_state, new_order):
|
183 |
+
"""Reorder buffered internal state (for incremental generation)."""
|
184 |
+
input_buffer = self._get_input_buffer(incremental_state)
|
185 |
+
if input_buffer is not None:
|
186 |
+
for k in input_buffer.keys():
|
187 |
+
input_buffer[k] = input_buffer[k].index_select(1, new_order)
|
188 |
+
self._set_input_buffer(incremental_state, input_buffer)
|
189 |
+
|
190 |
+
def _get_input_buffer(self, incremental_state):
|
191 |
+
return get_incremental_state(
|
192 |
+
self,
|
193 |
+
incremental_state,
|
194 |
+
'attn_state',
|
195 |
+
) or {}
|
196 |
+
|
197 |
+
def _set_input_buffer(self, incremental_state, buffer):
|
198 |
+
set_incremental_state(
|
199 |
+
self,
|
200 |
+
incremental_state,
|
201 |
+
'attn_state',
|
202 |
+
buffer,
|
203 |
+
)
|
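Note (not part of the committed file): a self-attention sketch for MultiheadAttention. Inputs are Time x Batch x Channel as the docstring states; passing the same tensor as query, key and value takes the fused qkv projection path, and mask_future_timesteps=True applies the causal mask used during training.

```python
import torch
from src.modules.multihead_attention import MultiheadAttention

attn = MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.1)
attn.eval()

x = torch.randn(10, 2, 64)  # 10 timesteps, batch of 2, 64 channels
out, weights = attn(query=x, key=x, value=x, mask_future_timesteps=True)

print(out.shape)      # torch.Size([10, 2, 64])
print(weights.shape)  # torch.Size([2, 10, 10]) - head-averaged attention
```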
src/modules/transformer_decoder.py
ADDED
@@ -0,0 +1,502 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
# Code adapted from https://github.com/pytorch/fairseq
|
4 |
+
# Copyright (c) 2017-present, Facebook, Inc.
|
5 |
+
# All rights reserved.
|
6 |
+
#
|
7 |
+
# This source code is licensed under the license found in the LICENSE file in
|
8 |
+
# https://github.com/pytorch/fairseq. An additional grant of patent rights
|
9 |
+
# can be found in the PATENTS file in the same directory.
|
10 |
+
|
11 |
+
import math
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch.nn.modules.utils import _single
|
16 |
+
import src.modules.utils as utils
|
17 |
+
from src.modules.multihead_attention import MultiheadAttention
|
18 |
+
import numpy as np
|
19 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
20 |
+
import copy
|
21 |
+
|
22 |
+
|
23 |
+
def make_positions(tensor, padding_idx, left_pad):
|
24 |
+
"""Replace non-padding symbols with their position numbers.
|
25 |
+
Position numbers begin at padding_idx+1.
|
26 |
+
Padding symbols are ignored, but it is necessary to specify whether padding
|
27 |
+
is added on the left side (left_pad=True) or right side (left_pad=False).
|
28 |
+
"""
|
29 |
+
|
30 |
+
# creates tensor from scratch - to avoid multigpu issues
|
31 |
+
max_pos = padding_idx + 1 + tensor.size(1)
|
32 |
+
#if not hasattr(make_positions, 'range_buf'):
|
33 |
+
range_buf = tensor.new()
|
34 |
+
#make_positions.range_buf = make_positions.range_buf.type_as(tensor)
|
35 |
+
if range_buf.numel() < max_pos:
|
36 |
+
torch.arange(padding_idx + 1, max_pos, out=range_buf)
|
37 |
+
mask = tensor.ne(padding_idx)
|
38 |
+
positions = range_buf[:tensor.size(1)].expand_as(tensor)
|
39 |
+
if left_pad:
|
40 |
+
positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
|
41 |
+
|
42 |
+
out = tensor.clone()
|
43 |
+
out = out.masked_scatter_(mask,positions[mask])
|
44 |
+
return out
|
45 |
+
|
46 |
+
|
47 |
+
class LearnedPositionalEmbedding(nn.Embedding):
|
48 |
+
"""This module learns positional embeddings up to a fixed maximum size.
|
49 |
+
Padding symbols are ignored, but it is necessary to specify whether padding
|
50 |
+
is added on the left side (left_pad=True) or right side (left_pad=False).
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
|
54 |
+
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
55 |
+
self.left_pad = left_pad
|
56 |
+
nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5)
|
57 |
+
|
58 |
+
def forward(self, input, incremental_state=None):
|
59 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
60 |
+
if incremental_state is not None:
|
61 |
+
# positions is the same for every token when decoding a single step
|
62 |
+
|
63 |
+
positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
|
64 |
+
else:
|
65 |
+
|
66 |
+
positions = make_positions(input.data, self.padding_idx, self.left_pad)
|
67 |
+
return super().forward(positions)
|
68 |
+
|
69 |
+
def max_positions(self):
|
70 |
+
"""Maximum number of supported positions."""
|
71 |
+
return self.num_embeddings - self.padding_idx - 1
|
72 |
+
|
73 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
74 |
+
"""This module produces sinusoidal positional embeddings of any length.
|
75 |
+
Padding symbols are ignored, but it is necessary to specify whether padding
|
76 |
+
is added on the left side (left_pad=True) or right side (left_pad=False).
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024):
|
80 |
+
super().__init__()
|
81 |
+
self.embedding_dim = embedding_dim
|
82 |
+
self.padding_idx = padding_idx
|
83 |
+
self.left_pad = left_pad
|
84 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
85 |
+
init_size,
|
86 |
+
embedding_dim,
|
87 |
+
padding_idx,
|
88 |
+
)
|
89 |
+
self.register_buffer('_float_tensor', torch.FloatTensor())
|
90 |
+
|
91 |
+
@staticmethod
|
92 |
+
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
|
93 |
+
"""Build sinusoidal embeddings.
|
94 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
95 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
96 |
+
"""
|
97 |
+
half_dim = embedding_dim // 2
|
98 |
+
emb = math.log(10000) / (half_dim - 1)
|
99 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
100 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
|
101 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
102 |
+
if embedding_dim % 2 == 1:
|
103 |
+
# zero pad
|
104 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
105 |
+
if padding_idx is not None:
|
106 |
+
emb[padding_idx, :] = 0
|
107 |
+
return emb
|
108 |
+
|
109 |
+
def forward(self, input, incremental_state=None):
|
110 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
111 |
+
# recompute/expand embeddings if needed
|
112 |
+
bsz, seq_len = input.size()
|
113 |
+
max_pos = self.padding_idx + 1 + seq_len
|
114 |
+
if self.weights is None or max_pos > self.weights.size(0):
|
115 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
116 |
+
max_pos,
|
117 |
+
self.embedding_dim,
|
118 |
+
self.padding_idx,
|
119 |
+
)
|
120 |
+
self.weights = self.weights.type_as(self._float_tensor)
|
121 |
+
|
122 |
+
if incremental_state is not None:
|
123 |
+
# positions is the same for every token when decoding a single step
|
124 |
+
return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
|
125 |
+
|
126 |
+
positions = make_positions(input.data, self.padding_idx, self.left_pad)
|
127 |
+
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
|
128 |
+
|
129 |
+
def max_positions(self):
|
130 |
+
"""Maximum number of supported positions."""
|
131 |
+
return int(1e5) # an arbitrary large number
|
132 |
+
|
133 |
+
class TransformerDecoderLayer(nn.Module):
|
134 |
+
"""Decoder layer block."""
|
135 |
+
|
136 |
+
def __init__(self, embed_dim, n_att, dropout=0.5, normalize_before=True, last_ln=False):
|
137 |
+
super().__init__()
|
138 |
+
|
139 |
+
self.embed_dim = embed_dim
|
140 |
+
self.dropout = dropout
|
141 |
+
self.relu_dropout = dropout
|
142 |
+
self.normalize_before = normalize_before
|
143 |
+
num_layer_norm = 3
|
144 |
+
|
145 |
+
# self-attention on generated recipe
|
146 |
+
self.self_attn = MultiheadAttention(
|
147 |
+
self.embed_dim, n_att,
|
148 |
+
dropout=dropout,
|
149 |
+
)
|
150 |
+
|
151 |
+
self.cond_att = MultiheadAttention(
|
152 |
+
self.embed_dim, n_att,
|
153 |
+
dropout=dropout,
|
154 |
+
)
|
155 |
+
|
156 |
+
self.fc1 = Linear(self.embed_dim, self.embed_dim)
|
157 |
+
self.fc2 = Linear(self.embed_dim, self.embed_dim)
|
158 |
+
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(num_layer_norm)])
|
159 |
+
self.use_last_ln = last_ln
|
160 |
+
if self.use_last_ln:
|
161 |
+
self.last_ln = LayerNorm(self.embed_dim)
|
162 |
+
|
163 |
+
def forward(self, x, ingr_features, ingr_mask, incremental_state, img_features):
|
164 |
+
|
165 |
+
# self attention
|
166 |
+
residual = x
|
167 |
+
x = self.maybe_layer_norm(0, x, before=True)
|
168 |
+
x, _ = self.self_attn(
|
169 |
+
query=x,
|
170 |
+
key=x,
|
171 |
+
value=x,
|
172 |
+
mask_future_timesteps=True,
|
173 |
+
incremental_state=incremental_state,
|
174 |
+
need_weights=False,
|
175 |
+
)
|
176 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
177 |
+
x = residual + x
|
178 |
+
x = self.maybe_layer_norm(0, x, after=True)
|
179 |
+
|
180 |
+
residual = x
|
181 |
+
x = self.maybe_layer_norm(1, x, before=True)
|
182 |
+
|
183 |
+
# attention
|
184 |
+
if ingr_features is None:
|
185 |
+
|
186 |
+
x, _ = self.cond_att(query=x,
|
187 |
+
key=img_features,
|
188 |
+
value=img_features,
|
189 |
+
key_padding_mask=None,
|
190 |
+
incremental_state=incremental_state,
|
191 |
+
static_kv=True,
|
192 |
+
)
|
193 |
+
elif img_features is None:
|
194 |
+
x, _ = self.cond_att(query=x,
|
195 |
+
key=ingr_features,
|
196 |
+
value=ingr_features,
|
197 |
+
key_padding_mask=ingr_mask,
|
198 |
+
incremental_state=incremental_state,
|
199 |
+
static_kv=True,
|
200 |
+
)
|
201 |
+
|
202 |
+
|
203 |
+
else:
|
204 |
+
# attention on concatenation of encoder_out and encoder_aux, query self attn (x)
|
205 |
+
kv = torch.cat((img_features, ingr_features), 0)
|
206 |
+
mask = torch.cat((torch.zeros(img_features.shape[1], img_features.shape[0], dtype=torch.uint8).to(device),
|
207 |
+
ingr_mask), 1)
|
208 |
+
x, _ = self.cond_att(query=x,
|
209 |
+
key=kv,
|
210 |
+
value=kv,
|
211 |
+
key_padding_mask=mask,
|
212 |
+
incremental_state=incremental_state,
|
213 |
+
static_kv=True,
|
214 |
+
)
|
215 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
216 |
+
x = residual + x
|
217 |
+
x = self.maybe_layer_norm(1, x, after=True)
|
218 |
+
|
219 |
+
residual = x
|
220 |
+
x = self.maybe_layer_norm(-1, x, before=True)
|
221 |
+
x = F.relu(self.fc1(x))
|
222 |
+
x = F.dropout(x, p=self.relu_dropout, training=self.training)
|
223 |
+
x = self.fc2(x)
|
224 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
225 |
+
x = residual + x
|
226 |
+
x = self.maybe_layer_norm(-1, x, after=True)
|
227 |
+
|
228 |
+
if self.use_last_ln:
|
229 |
+
x = self.last_ln(x)
|
230 |
+
|
231 |
+
return x
|
232 |
+
|
233 |
+
def maybe_layer_norm(self, i, x, before=False, after=False):
|
234 |
+
assert before ^ after
|
235 |
+
if after ^ self.normalize_before:
|
236 |
+
return self.layer_norms[i](x)
|
237 |
+
else:
|
238 |
+
return x
|
239 |
+
|
240 |
+
class DecoderTransformer(nn.Module):
|
241 |
+
"""Transformer decoder."""
|
242 |
+
|
243 |
+
def __init__(self, embed_size, vocab_size, dropout=0.5, seq_length=20, num_instrs=15,
|
244 |
+
attention_nheads=16, pos_embeddings=True, num_layers=8, learned=True, normalize_before=True,
|
245 |
+
normalize_inputs=False, last_ln=False, scale_embed_grad=False):
|
246 |
+
super(DecoderTransformer, self).__init__()
|
247 |
+
self.dropout = dropout
|
248 |
+
self.seq_length = seq_length * num_instrs
|
249 |
+
self.embed_tokens = nn.Embedding(vocab_size, embed_size, padding_idx=vocab_size-1,
|
250 |
+
scale_grad_by_freq=scale_embed_grad)
|
251 |
+
nn.init.normal_(self.embed_tokens.weight, mean=0, std=embed_size ** -0.5)
|
252 |
+
if pos_embeddings:
|
253 |
+
self.embed_positions = PositionalEmbedding(1024, embed_size, 0, left_pad=False, learned=learned)
|
254 |
+
else:
|
255 |
+
self.embed_positions = None
|
256 |
+
self.normalize_inputs = normalize_inputs
|
257 |
+
if self.normalize_inputs:
|
258 |
+
self.layer_norms_in = nn.ModuleList([LayerNorm(embed_size) for i in range(3)])
|
259 |
+
|
260 |
+
self.embed_scale = math.sqrt(embed_size)
|
261 |
+
self.layers = nn.ModuleList([])
|
262 |
+
self.layers.extend([
|
263 |
+
TransformerDecoderLayer(embed_size, attention_nheads, dropout=dropout, normalize_before=normalize_before,
|
264 |
+
last_ln=last_ln)
|
265 |
+
for i in range(num_layers)
|
266 |
+
])
|
267 |
+
|
268 |
+
self.linear = Linear(embed_size, vocab_size-1)
|
269 |
+
|
270 |
+
def forward(self, ingr_features, ingr_mask, captions, img_features, incremental_state=None):
|
271 |
+
|
272 |
+
if ingr_features is not None:
|
273 |
+
ingr_features = ingr_features.permute(0, 2, 1)
|
274 |
+
ingr_features = ingr_features.transpose(0, 1)
|
275 |
+
if self.normalize_inputs:
|
276 |
+
self.layer_norms_in[0](ingr_features)
|
277 |
+
|
278 |
+
if img_features is not None:
|
279 |
+
img_features = img_features.permute(0, 2, 1)
|
280 |
+
img_features = img_features.transpose(0, 1)
|
281 |
+
if self.normalize_inputs:
|
282 |
+
self.layer_norms_in[1](img_features)
|
283 |
+
|
284 |
+
if ingr_mask is not None:
|
285 |
+
ingr_mask = (1-ingr_mask.squeeze(1)).byte()
|
286 |
+
|
287 |
+
# embed positions
|
288 |
+
if self.embed_positions is not None:
|
289 |
+
positions = self.embed_positions(captions, incremental_state=incremental_state)
|
290 |
+
if incremental_state is not None:
|
291 |
+
if self.embed_positions is not None:
|
292 |
+
positions = positions[:, -1:]
|
293 |
+
captions = captions[:, -1:]
|
294 |
+
|
295 |
+
# embed tokens and positions
|
296 |
+
x = self.embed_scale * self.embed_tokens(captions)
|
297 |
+
|
298 |
+
if self.embed_positions is not None:
|
299 |
+
x += positions
|
300 |
+
|
301 |
+
if self.normalize_inputs:
|
302 |
+
x = self.layer_norms_in[2](x)
|
303 |
+
|
304 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
305 |
+
|
306 |
+
# B x T x C -> T x B x C
|
307 |
+
x = x.transpose(0, 1)
|
308 |
+
|
309 |
+
for p, layer in enumerate(self.layers):
|
310 |
+
x = layer(
|
311 |
+
x,
|
312 |
+
ingr_features,
|
313 |
+
ingr_mask,
|
314 |
+
incremental_state,
|
315 |
+
img_features
|
316 |
+
)
|
317 |
+
|
318 |
+
# T x B x C -> B x T x C
|
319 |
+
x = x.transpose(0, 1)
|
320 |
+
|
321 |
+
x = self.linear(x)
|
322 |
+
_, predicted = x.max(dim=-1)
|
323 |
+
|
324 |
+
return x, predicted
|
325 |
+
|
326 |
+
def sample(self, ingr_features, ingr_mask, greedy=True, temperature=1.0, beam=-1,
|
327 |
+
img_features=None, first_token_value=0,
|
328 |
+
replacement=True, last_token_value=0):
|
329 |
+
|
330 |
+
incremental_state = {}
|
331 |
+
|
332 |
+
# create dummy previous word
|
333 |
+
if ingr_features is not None:
|
334 |
+
fs = ingr_features.size(0)
|
335 |
+
else:
|
336 |
+
fs = img_features.size(0)
|
337 |
+
|
338 |
+
if beam != -1:
|
339 |
+
if fs == 1:
|
340 |
+
return self.sample_beam(ingr_features, ingr_mask, beam, img_features, first_token_value,
|
341 |
+
replacement, last_token_value)
|
342 |
+
else:
|
343 |
+
print ("Beam Search can only be used with batch size of 1. Running greedy or temperature sampling...")
|
344 |
+
|
345 |
+
        first_word = torch.ones(fs)*first_token_value

        first_word = first_word.to(device).long()
        sampled_ids = [first_word]
        logits = []

        for i in range(self.seq_length):
            # forward
            outputs, _ = self.forward(ingr_features, ingr_mask, torch.stack(sampled_ids, 1),
                                      img_features, incremental_state)
            outputs = outputs.squeeze(1)
            if not replacement:
                # predicted mask
                if i == 0:
                    predicted_mask = torch.zeros(outputs.shape).float().to(device)
                else:
                    # ensure no repetitions in sampling if replacement==False
                    batch_ind = [j for j in range(fs) if sampled_ids[i][j] != 0]
                    sampled_ids_new = sampled_ids[i][batch_ind]
                    predicted_mask[batch_ind, sampled_ids_new] = float('-inf')

                # mask previously selected ids
                outputs += predicted_mask

            logits.append(outputs)
            if greedy:
                outputs_prob = torch.nn.functional.softmax(outputs, dim=-1)
                _, predicted = outputs_prob.max(1)
                predicted = predicted.detach()
            else:
                k = 10
                outputs_prob = torch.div(outputs.squeeze(1), temperature)
                outputs_prob = torch.nn.functional.softmax(outputs_prob, dim=-1).data

                # top k random sampling
                prob_prev_topk, indices = torch.topk(outputs_prob, k=k, dim=1)
                predicted = torch.multinomial(prob_prev_topk, 1).view(-1)
                predicted = torch.index_select(indices, dim=1, index=predicted)[:, 0].detach()

            sampled_ids.append(predicted)

        sampled_ids = torch.stack(sampled_ids[1:], 1)
        logits = torch.stack(logits, 1)

        return sampled_ids, logits

    def sample_beam(self, ingr_features, ingr_mask, beam=3, img_features=None, first_token_value=0,
                    replacement=True, last_token_value=0):
        k = beam
        alpha = 0.0
        # create dummy previous word
        if ingr_features is not None:
            fs = ingr_features.size(0)
        else:
            fs = img_features.size(0)
        first_word = torch.ones(fs)*first_token_value

        first_word = first_word.to(device).long()

        sequences = [[[first_word], 0, {}, False, 1]]
        finished = []

        for i in range(self.seq_length):
            # forward
            all_candidates = []
            for rem in range(len(sequences)):
                incremental = sequences[rem][2]
                outputs, _ = self.forward(ingr_features, ingr_mask, torch.stack(sequences[rem][0], 1),
                                          img_features, incremental)
                outputs = outputs.squeeze(1)
                if not replacement:
                    # predicted mask
                    if i == 0:
                        predicted_mask = torch.zeros(outputs.shape).float().to(device)
                    else:
                        # ensure no repetitions in sampling if replacement==False
                        batch_ind = [j for j in range(fs) if sequences[rem][0][i][j] != 0]
                        sampled_ids_new = sequences[rem][0][i][batch_ind]
                        predicted_mask[batch_ind, sampled_ids_new] = float('-inf')

                    # mask previously selected ids
                    outputs += predicted_mask

                outputs_prob = torch.nn.functional.log_softmax(outputs, dim=-1)
                probs, indices = torch.topk(outputs_prob, beam)
                # tokens is [batch x beam] and every element is a list
                # score is [batch x beam] and every element is a scalar
                # incremental is [batch x beam] and every element is a dict

                for bid in range(beam):
                    tokens = sequences[rem][0] + [indices[:, bid]]
                    score = sequences[rem][1] + probs[:, bid].squeeze().item()
                    if indices[:, bid].item() == last_token_value:
                        finished.append([tokens, score, None, True, sequences[rem][-1] + 1])
                    else:
                        all_candidates.append([tokens, score, incremental, False, sequences[rem][-1] + 1])

            # if all the top-k scoring beams have finished, we can return them
            ordered_all = sorted(all_candidates + finished,
                                 key=lambda tup: tup[1]/(np.power(tup[-1], alpha)), reverse=True)[:k]
            if all(el[-1] == True for el in ordered_all):
                all_candidates = []

            # order all candidates by score
            ordered = sorted(all_candidates, key=lambda tup: tup[1]/(np.power(tup[-1], alpha)), reverse=True)
            # select k best
            sequences = ordered[:k]
            finished = sorted(finished, key=lambda tup: tup[1]/(np.power(tup[-1], alpha)), reverse=True)[:k]

        if len(finished) != 0:
            sampled_ids = torch.stack(finished[0][0][1:], 1)
            logits = finished[0][1]
        else:
            sampled_ids = torch.stack(sequences[0][0][1:], 1)
            logits = sequences[0][1]
        return sampled_ids, logits

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

    def upgrade_state_dict(self, state_dict):
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            if 'decoder.embed_positions.weights' in state_dict:
                del state_dict['decoder.embed_positions.weights']
            if 'decoder.embed_positions._float_tensor' not in state_dict:
                state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
        return state_dict


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    return m


def LayerNorm(embedding_dim):
    m = nn.LayerNorm(embedding_dim)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    nn.init.constant_(m.bias, 0.)
    return m


def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
    if learned:
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
        nn.init.constant_(m.weight[padding_idx], 0)
    else:
        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings)
    return m
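Note on the sampling loop above: when greedy is False, the decoder rescales the logits by the temperature, keeps the top k=10 candidates and draws one of them at random. Below is a standalone sketch of that step, assuming dummy values for logits, temperature and k; it uses torch.gather to map the sampled column back to a vocabulary id, which matches the index_select used above for a single sample.

import torch

torch.manual_seed(0)
logits = torch.randn(2, 10)                                 # (batch, vocab) dummy scores
temperature, k = 1.0, 3

probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
topk_probs, topk_ids = torch.topk(probs, k=k, dim=1)        # keep the k most likely tokens per sample
choice = torch.multinomial(topk_probs, 1)                   # sample one of the k columns
predicted = torch.gather(topk_ids, 1, choice).squeeze(1)    # map back to vocabulary ids
print(predicted)                                            # two sampled token ids, one per batch element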
src/modules/utils.py
ADDED
@@ -0,0 +1,387 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

# Code adapted from https://github.com/pytorch/fairseq
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# https://github.com/pytorch/fairseq. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import defaultdict, OrderedDict
import logging
import os
import re
import torch
import traceback

from torch.serialization import default_restore_location


def torch_persistent_save(*args, **kwargs):
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                logging.error(traceback.format_exc())


def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict


def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
               num_updates, optim_history=None, extra_state=None):
    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        'args': args,
        'model': convert_state_dict_type(model.state_dict()),
        'optimizer_history': optim_history + [
            {
                'criterion_name': criterion.__class__.__name__,
                'optimizer_name': optimizer.__class__.__name__,
                'lr_scheduler_state': lr_scheduler.state_dict(),
                'num_updates': num_updates,
            }
        ],
        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
        'extra_state': extra_state,
    }
    torch_persistent_save(state_dict, filename)


def load_model_state(filename, model):
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'], strict=True)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    # add optimizer_history
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'best_loss': state['best_loss'],
            },
        ]
        state['last_optimizer_state'] = state['optimizer']
        del state['optimizer']
        del state['best_loss']
    # move extra_state into sub-dictionary
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    # reduce optimizer history's memory usage (only keep the last state)
    if 'optimizer' in state['optimizer_history'][-1]:
        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
        for optim_hist in state['optimizer_history']:
            del optim_hist['optimizer']
    # record the optimizer class name
    if 'optimizer_name' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
    # move best_loss into lr_scheduler_state
    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['lr_scheduler_state'] = {
            'best': state['optimizer_history'][-1]['best_loss'],
        }
        del state['optimizer_history'][-1]['best_loss']
    # keep track of number of updates
    if 'num_updates' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['num_updates'] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
        state['args'].max_source_positions = state['args'].max_positions
        state['args'].max_target_positions = state['args'].max_positions
    # use stateful training data iterator
    if 'train_iterator' not in state['extra_state']:
        state['extra_state']['train_iterator'] = {
            'epoch': state['extra_state']['epoch'],
            'iterations_in_epoch': 0,
        }
    return state


def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.
    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
    {'arg_name': arg} -- to override model args that were used during model
    training
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        state = _upgrade_state_dict(state)
        states.append(state)
    args = states[0]['args']
    if model_arg_overrides is not None:
        args = _override_model_args(args, model_arg_overrides)

    # build ensemble
    ensemble = []
    for state in states:
        model = task.build_model(args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)
    return ensemble, args


def _override_model_args(args, model_arg_overrides):
    # Uses model_arg_overrides {'arg_name': arg} to override model args
    for arg_name, arg_val in model_arg_overrides.items():
        setattr(args, arg_name, arg_val)
    return args


def move_to_cuda(sample):
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda()
        elif isinstance(maybe_tensor, dict):
            return {
                key: _move_to_cuda(value)
                for key, value in maybe_tensor.items()
            }
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        else:
            return maybe_tensor

    return _move_to_cuda(sample)


INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)


def _get_full_incremental_state_key(module_instance, key):
    module_name = module_instance.__class__.__name__

    # assign a unique ID to each module instance, so that incremental state is
    # not shared across module instances
    if not hasattr(module_instance, '_fairseq_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key)


def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module."""
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]


def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value


def load_align_dict(replace_unk):
    if replace_unk is None:
        align_dict = None
    elif isinstance(replace_unk, str):
        # Load alignment dictionary for unknown word replacement if it was passed as an argument.
        align_dict = {}
        with open(replace_unk, 'r') as f:
            for line in f:
                cols = line.split()
                align_dict[cols[0]] = cols[1]
    else:
        # No alignment dictionary provided but we still want to perform unknown word replacement by copying the
        # original source word.
        align_dict = {}
    return align_dict


def print_embed_overlap(embed_dict, vocab_dict):
    embed_keys = set(embed_dict.keys())
    vocab_keys = set(vocab_dict.symbols)
    overlap = len(embed_keys & vocab_keys)
    print("| Found {}/{} types in embedding file.".format(overlap, len(vocab_dict)))


def parse_embedding(embed_path):
    """Parse embedding text file into a dictionary of word and embedding tensors.
    The first line can have vocabulary size and dimension. The following lines
    should contain word and embedding separated by spaces.
    Example:
        2 5
        the -0.0230 -0.0264 0.0287 0.0171 0.1403
        at -0.0395 -0.1286 0.0275 0.0254 -0.0932
    """
    embed_dict = {}
    with open(embed_path) as f_embed:
        next(f_embed)  # skip header
        for line in f_embed:
            pieces = line.rstrip().split(" ")
            embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
    return embed_dict


def load_embedding(embed_dict, vocab, embedding):
    for idx in range(len(vocab)):
        token = vocab[idx]
        if token in embed_dict:
            embedding.weight.data[idx] = embed_dict[token]
    return embedding


def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
    from fairseq import tokenizer
    # Tokens are strings here
    hypo_tokens = tokenizer.tokenize_line(hypo_str)
    # TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
    src_tokens = tokenizer.tokenize_line(src_str) + ['<eos>']
    for i, ht in enumerate(hypo_tokens):
        if ht == unk:
            src_token = src_tokens[alignment[i]]
            # Either take the corresponding value in the aligned dictionary or just copy the original value.
            hypo_tokens[i] = align_dict.get(src_token, src_token)
    return ' '.join(hypo_tokens)


def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
    from fairseq import tokenizer
    hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
    if align_dict is not None:
        hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
    if align_dict is not None or remove_bpe is not None:
        # Convert back to tokens for evaluating with unk replacement or without BPE
        # Note that the dictionary can be modified inside the method.
        hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
    return hypo_tokens, hypo_str, alignment


def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.
    Position numbers begin at padding_idx+1.
    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    max_pos = padding_idx + 1 + tensor.size(1)
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tensor.clone().masked_scatter_(mask, positions[mask])


def strip_pad(tensor, pad):
    return tensor[tensor.ne(pad)]


def buffered_arange(max):
    if not hasattr(buffered_arange, 'buf'):
        buffered_arange.buf = torch.LongTensor()
    if max > buffered_arange.buf.numel():
        torch.arange(max, out=buffered_arange.buf)
    return buffered_arange.buf[:max]


def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    range = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(range - num_pads, max_len)
    else:
        index = torch.remainder(range + num_pads, max_len)
    return src_tokens.gather(1, index)


def item(tensor):
    if hasattr(tensor, 'item'):
        return tensor.item()
    if hasattr(tensor, '__getitem__'):
        return tensor[0]
    return tensor


def clip_grad_norm_(tensor, max_norm):
    grad_norm = item(torch.norm(tensor))
    if grad_norm > max_norm > 0:
        clip_coef = max_norm / (grad_norm + 1e-6)
        tensor.mul_(clip_coef)
    return grad_norm


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieves all checkpoints found in `path` directory.
    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = os.listdir(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = int(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
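The incremental-state helpers above key a shared cache dictionary by module class name plus a per-instance id, so each decoder layer keeps its own slot. Below is a minimal sketch of the intended usage; the DummyLayer module and the 'prev_steps' key are made up for illustration.

import torch
import torch.nn as nn

class DummyLayer(nn.Module):
    def forward(self, x, incremental_state=None):
        prev = get_incremental_state(self, incremental_state, 'prev_steps')
        out = x if prev is None else torch.cat([prev, x], dim=1)   # append the new step to the cached ones
        set_incremental_state(self, incremental_state, 'prev_steps', out)
        return out

layer = DummyLayer()
state = {}
layer(torch.ones(1, 1, 4), state)
layer(torch.ones(1, 1, 4), state)
print(list(state.keys()))   # e.g. ['DummyLayer.1.prev_steps']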
src/read_pkl.py
ADDED
@@ -0,0 +1,7 @@
import pickle

with open("../data/predicted_ingr.pkl", "rb") as fp:  # Unpickling
    b = pickle.load(fp)

print(b)
src/sample.py
ADDED
@@ -0,0 +1,207 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch
import numpy as np
from args import get_parser
import pickle
import os
from torchvision import transforms
from build_vocab import Vocabulary
from model import get_model
from tqdm import tqdm
from data_loader import get_loader
import json
import sys
from model import mask_from_eos
import random
from utils.metrics import softIoU, update_error_types, compute_metrics
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
map_loc = None if torch.cuda.is_available() else 'cpu'


def compute_score(sampled_ids):

    if 1 in sampled_ids:
        cut = np.where(sampled_ids == 1)[0][0]
    else:
        cut = -1
    sampled_ids = sampled_ids[0:cut]
    score = float(len(set(sampled_ids))) / float(len(sampled_ids))

    return score


def label2onehot(labels, pad_value):

    # input labels to one hot vector
    inp_ = torch.unsqueeze(labels, 2)
    one_hot = torch.FloatTensor(labels.size(0), labels.size(1), pad_value + 1).zero_().to(device)
    one_hot.scatter_(2, inp_, 1)
    one_hot, _ = one_hot.max(dim=1)
    # remove pad and eos position
    one_hot = one_hot[:, 1:-1]
    one_hot[:, 0] = 0

    return one_hot


def main(args):

    where_to_save = os.path.join(args.save_dir, args.project_name, args.model_name)
    checkpoints_dir = os.path.join(where_to_save, 'checkpoints')
    logs_dir = os.path.join(where_to_save, 'logs')

    if not args.log_term:
        print("Eval logs will be saved to:", os.path.join(logs_dir, 'eval.log'))
        sys.stdout = open(os.path.join(logs_dir, 'eval.log'), 'w')
        sys.stderr = open(os.path.join(logs_dir, 'eval.err'), 'w')

    vars_to_replace = ['greedy', 'recipe_only', 'ingrs_only', 'temperature', 'batch_size', 'maxseqlen',
                       'get_perplexity', 'use_true_ingrs', 'eval_split', 'save_dir', 'aux_data_dir',
                       'recipe1m_dir', 'project_name', 'use_lmdb', 'beam']
    store_dict = {}
    for var in vars_to_replace:
        store_dict[var] = getattr(args, var)
    args = pickle.load(open(os.path.join(checkpoints_dir, 'args.pkl'), 'rb'))
    for var in vars_to_replace:
        setattr(args, var, store_dict[var])
    print(args)

    transforms_list = []
    transforms_list.append(transforms.Resize((args.crop_size)))
    transforms_list.append(transforms.CenterCrop(args.crop_size))
    transforms_list.append(transforms.ToTensor())
    transforms_list.append(transforms.Normalize((0.485, 0.456, 0.406),
                                                (0.229, 0.224, 0.225)))
    # Image preprocessing
    transform = transforms.Compose(transforms_list)

    # data loader
    data_dir = args.recipe1m_dir
    data_loader, dataset = get_loader(data_dir, args.aux_data_dir, args.eval_split,
                                      args.maxseqlen, args.maxnuminstrs, args.maxnumlabels,
                                      args.maxnumims, transform, args.batch_size,
                                      shuffle=False, num_workers=args.num_workers,
                                      drop_last=False, max_num_samples=-1,
                                      use_lmdb=args.use_lmdb, suff=args.suff)

    ingr_vocab_size = dataset.get_ingrs_vocab_size()
    instrs_vocab_size = dataset.get_instrs_vocab_size()

    args.numgens = 1

    # Build the model
    model = get_model(args, ingr_vocab_size, instrs_vocab_size)
    model_path = os.path.join(args.save_dir, args.project_name, args.model_name, 'checkpoints', 'modelbest.ckpt')

    # overwrite flags for inference
    model.recipe_only = args.recipe_only
    model.ingrs_only = args.ingrs_only

    # Load the trained model parameters
    model.load_state_dict(torch.load(model_path, map_location=map_loc))

    model.eval()
    model = model.to(device)
    results_dict = {'recipes': {}, 'ingrs': {}, 'ingr_iou': {}}
    captions = {}
    iou = []
    error_types = {'tp_i': 0, 'fp_i': 0, 'fn_i': 0, 'tn_i': 0, 'tp_all': 0, 'fp_all': 0, 'fn_all': 0}
    perplexity_list = []
    n_rep, th = 0, 0.3

    for i, (img_inputs, true_caps_batch, ingr_gt, imgid, impath) in tqdm(enumerate(data_loader)):

        ingr_gt = ingr_gt.to(device)
        true_caps_batch = true_caps_batch.to(device)

        true_caps_shift = true_caps_batch.clone()[:, 1:].contiguous()
        img_inputs = img_inputs.to(device)

        true_ingrs = ingr_gt if args.use_true_ingrs else None
        for gens in range(args.numgens):
            with torch.no_grad():

                if args.get_perplexity:

                    losses = model(img_inputs, true_caps_batch, ingr_gt, keep_cnn_gradients=False)
                    recipe_loss = losses['recipe_loss']
                    recipe_loss = recipe_loss.view(true_caps_shift.size())
                    non_pad_mask = true_caps_shift.ne(instrs_vocab_size - 1).float()
                    recipe_loss = torch.sum(recipe_loss*non_pad_mask, dim=-1) / torch.sum(non_pad_mask, dim=-1)
                    perplexity = torch.exp(recipe_loss)

                    perplexity = perplexity.detach().cpu().numpy().tolist()
                    perplexity_list.extend(perplexity)

                else:

                    outputs = model.sample(img_inputs, args.greedy, args.temperature, args.beam, true_ingrs)

                    if not args.recipe_only:
                        fake_ingrs = outputs['ingr_ids']
                        pred_one_hot = label2onehot(fake_ingrs, ingr_vocab_size - 1)
                        target_one_hot = label2onehot(ingr_gt, ingr_vocab_size - 1)
                        iou_item = torch.mean(softIoU(pred_one_hot, target_one_hot)).item()
                        iou.append(iou_item)

                        update_error_types(error_types, pred_one_hot, target_one_hot)

                        fake_ingrs = fake_ingrs.detach().cpu().numpy()

                        for ingr_idx, fake_ingr in enumerate(fake_ingrs):

                            iou_item = softIoU(pred_one_hot[ingr_idx].unsqueeze(0),
                                               target_one_hot[ingr_idx].unsqueeze(0)).item()
                            results_dict['ingrs'][imgid[ingr_idx]] = []
                            results_dict['ingrs'][imgid[ingr_idx]].append(fake_ingr)
                            results_dict['ingr_iou'][imgid[ingr_idx]] = iou_item

                    if not args.ingrs_only:
                        sampled_ids_batch = outputs['recipe_ids']
                        sampled_ids_batch = sampled_ids_batch.cpu().detach().numpy()

                        for j, sampled_ids in enumerate(sampled_ids_batch):
                            score = compute_score(sampled_ids)
                            if score < th:
                                n_rep += 1
                            if imgid[j] not in captions.keys():
                                results_dict['recipes'][imgid[j]] = []
                            results_dict['recipes'][imgid[j]].append(sampled_ids)
    if args.get_perplexity:
        print(len(perplexity_list))
        print(np.mean(perplexity_list))
    else:

        if not args.recipe_only:
            ret_metrics = {'accuracy': [], 'f1': [], 'jaccard': [], 'f1_ingredients': []}
            compute_metrics(ret_metrics, error_types, ['accuracy', 'f1', 'jaccard', 'f1_ingredients'],
                            eps=1e-10,
                            weights=None)

            for k, v in ret_metrics.items():
                print(k, np.mean(v))

        if args.greedy:
            suff = 'greedy'
        else:
            if args.beam != -1:
                suff = 'beam_'+str(args.beam)
            else:
                suff = 'temp_' + str(args.temperature)

        results_file = os.path.join(args.save_dir, args.project_name, args.model_name, 'checkpoints',
                                    args.eval_split + '_' + suff + '_gencaps.pkl')
        print(results_file)
        pickle.dump(results_dict, open(results_file, 'wb'))

    print("Number of samples with excessive repetitions:", n_rep)


if __name__ == '__main__':
    args = get_parser()
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)
    random.seed(1234)
    np.random.seed(1234)
    main(args)
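compute_score in sample.py above is a quick repetition check: it cuts the generated id sequence at the first end-of-recipe token (id 1) and returns the fraction of unique tokens before it, and main counts a sample as overly repetitive when that fraction drops below th = 0.3. A small worked example with made-up ids:

import numpy as np

sampled_ids = np.array([5, 7, 7, 7, 9, 1, 0, 0])   # id 1 marks the end of the sequence
# tokens before the end token: [5, 7, 7, 7, 9] -> 3 unique out of 5
print(compute_score(sampled_ids))                   # 0.6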
src/sim_ingr.py
ADDED
@@ -0,0 +1,197 @@
import nltk
import pickle
import argparse
from collections import Counter
import json
import os
from tqdm import *
import numpy as np
import re


def get_ingredient(det_ingr, replace_dict):
    det_ingr_undrs = det_ingr['text'].lower()
    det_ingr_undrs = ''.join(i for i in det_ingr_undrs if not i.isdigit())

    for rep, char_list in replace_dict.items():
        for c_ in char_list:
            if c_ in det_ingr_undrs:
                det_ingr_undrs = det_ingr_undrs.replace(c_, rep)
    det_ingr_undrs = det_ingr_undrs.strip()
    det_ingr_undrs = det_ingr_undrs.replace(' ', '_')

    return det_ingr_undrs


def remove_plurals(counter_ingrs, ingr_clusters):
    del_ingrs = []

    for k, v in counter_ingrs.items():

        if len(k) == 0:
            del_ingrs.append(k)
            continue

        gotit = 0
        if k[-2:] == 'es':
            if k[:-2] in counter_ingrs.keys():
                counter_ingrs[k[:-2]] += v
                ingr_clusters[k[:-2]].extend(ingr_clusters[k])
                del_ingrs.append(k)
                gotit = 1

        if k[-1] == 's' and gotit == 0:
            if k[:-1] in counter_ingrs.keys():
                counter_ingrs[k[:-1]] += v
                ingr_clusters[k[:-1]].extend(ingr_clusters[k])
                del_ingrs.append(k)
    for item in del_ingrs:
        del counter_ingrs[item]
        del ingr_clusters[item]
    return counter_ingrs, ingr_clusters


def cluster_ingredients(counter_ingrs):
    mydict = dict()
    mydict_ingrs = dict()

    for k, v in counter_ingrs.items():

        w1 = k.split('_')[-1]
        w2 = k.split('_')[0]
        lw = [w1, w2]
        if len(k.split('_')) > 1:
            w3 = k.split('_')[0] + '_' + k.split('_')[1]
            w4 = k.split('_')[-2] + '_' + k.split('_')[-1]

            lw = [w1, w2, w4, w3]

        gotit = 0
        for w in lw:
            if w in counter_ingrs.keys():
                # check if its parts are
                parts = w.split('_')
                if len(parts) > 0:
                    if parts[0] in counter_ingrs.keys():
                        w = parts[0]
                    elif parts[1] in counter_ingrs.keys():
                        w = parts[1]
                if w in mydict.keys():
                    mydict[w] += v
                    mydict_ingrs[w].append(k)
                else:
                    mydict[w] = v
                    mydict_ingrs[w] = [k]
                gotit = 1
                break
        if gotit == 0:
            mydict[k] = v
            mydict_ingrs[k] = [k]

    return mydict, mydict_ingrs


def update_counter(list_, counter_toks, istrain=False):
    for sentence in list_:
        tokens = nltk.tokenize.word_tokenize(sentence)
        if istrain:
            counter_toks.update(tokens)


def build_vocab_recipe1m(args):
    print("Loading data...")
    dets = json.load(open(os.path.join(args.recipe1m_path, 'det_ingrs.json'), 'r'))

    replace_dict_ingrs = {'and': ['&', "'n"], '': ['%', ',', '.', '#', '[', ']', '!', '?']}
    replace_dict_instrs = {'and': ['&', "'n"], '': ['#', '[', ']']}

    idx2ind = {}
    for i, entry in enumerate(dets):
        idx2ind[entry['id']] = i

    ingrs_file = args.save_path + 'allingrs_count.pkl'
    instrs_file = args.save_path + 'allwords_count.pkl'

    # manually add missing entries for better clustering
    base_words = ['peppers', 'tomato', 'spinach_leaves', 'turkey_breast', 'lettuce_leaf',
                  'chicken_thighs', 'milk_powder', 'bread_crumbs', 'onion_flakes',
                  'red_pepper', 'pepper_flakes', 'juice_concentrate', 'cracker_crumbs', 'hot_chili',
                  'seasoning_mix', 'dill_weed', 'pepper_sauce', 'sprouts', 'cooking_spray', 'cheese_blend',
                  'basil_leaves', 'pineapple_chunks', 'marshmallow', 'chile_powder',
                  'cheese_blend', 'corn_kernels', 'tomato_sauce', 'chickens', 'cracker_crust',
                  'lemonade_concentrate', 'red_chili', 'mushroom_caps', 'mushroom_cap', 'breaded_chicken',
                  'frozen_pineapple', 'pineapple_chunks', 'seasoning_mix', 'seaweed', 'onion_flakes',
                  'bouillon_granules', 'lettuce_leaf', 'stuffing_mix', 'parsley_flakes', 'chicken_breast',
                  'basil_leaves', 'baguettes', 'green_tea', 'peanut_butter', 'green_onion', 'fresh_cilantro',
                  'breaded_chicken', 'hot_pepper', 'dried_lavender', 'white_chocolate',
                  'dill_weed', 'cake_mix', 'cheese_spread', 'turkey_breast', 'chucken_thighs', 'basil_leaves',
                  'mandarin_orange', 'laurel', 'cabbage_head', 'pistachio', 'cheese_dip',
                  'thyme_leave', 'boneless_pork', 'red_pepper', 'onion_dip', 'skinless_chicken', 'dark_chocolate',
                  'canned_corn', 'muffin', 'cracker_crust', 'bread_crumbs', 'frozen_broccoli',
                  'philadelphia', 'cracker_crust', 'chicken_breast']

    for base_word in base_words:

        if base_word not in counter_ingrs.keys():
            counter_ingrs[base_word] = 1

    counter_ingrs, cluster_ingrs = cluster_ingredients(counter_ingrs)
    counter_ingrs, cluster_ingrs = remove_plurals(counter_ingrs, cluster_ingrs)

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter_toks.items() if cnt >= args.threshold_words]
    ingrs = {word: cnt for word, cnt in counter_ingrs.items() if cnt >= args.threshold_ingrs}


def main(args):

    vocab_ingrs, vocab_toks, dataset = build_vocab_recipe1m(args)

    with open(os.path.join(args.save_path, args.suff+'recipe1m_vocab_ingrs.pkl'), 'wb') as f:
        pickle.dump(vocab_ingrs, f)
    with open(os.path.join(args.save_path, args.suff+'recipe1m_vocab_toks.pkl'), 'wb') as f:
        pickle.dump(vocab_toks, f)

    for split in dataset.keys():
        with open(os.path.join(args.save_path, args.suff+'recipe1m_' + split + '.pkl'), 'wb') as f:
            pickle.dump(dataset[split], f)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--recipe1m_path', type=str,
                        default='path/to/recipe1m',
                        help='recipe1m path')

    parser.add_argument('--save_path', type=str, default='../data/',
                        help='path for saving vocabulary wrapper')

    parser.add_argument('--suff', type=str, default='')

    parser.add_argument('--threshold_ingrs', type=int, default=10,
                        help='minimum ingr count threshold')

    parser.add_argument('--threshold_words', type=int, default=10,
                        help='minimum word count threshold')

    parser.add_argument('--maxnuminstrs', type=int, default=20,
                        help='max number of instructions (sentences)')

    parser.add_argument('--maxnumingrs', type=int, default=20,
                        help='max number of ingredients')

    parser.add_argument('--minnuminstrs', type=int, default=2,
                        help='min number of instructions (sentences)')

    parser.add_argument('--minnumingrs', type=int, default=2,
                        help='min number of ingredients')

    parser.add_argument('--minnumwords', type=int, default=20,
                        help='minimum number of characters in recipe')

    parser.add_argument('--forcegen', dest='forcegen', action='store_true')
    parser.set_defaults(forcegen=False)

    args = parser.parse_args()
    main(args)
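get_ingredient above normalizes one detected ingredient string: lower-case, drop digits, apply the character replacement table, then join the remaining words with underscores. A small example using the same replacement table as build_vocab_recipe1m (the input dict mimics a det_ingrs.json entry and is made up):

replace_dict_ingrs = {'and': ['&', "'n"], '': ['%', ',', '.', '#', '[', ']', '!', '?']}
print(get_ingredient({'text': '2 Red Bell Peppers, diced'}, replace_dict_ingrs))
# -> 'red_bell_peppers_diced'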
src/train.py
ADDED
@@ -0,0 +1,398 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
2 |
+
|
3 |
+
from args import get_parser
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.autograd as autograd
|
7 |
+
import numpy as np
|
8 |
+
import os
|
9 |
+
import random
|
10 |
+
import pickle
|
11 |
+
from data_loader import get_loader
|
12 |
+
from build_vocab import Vocabulary
|
13 |
+
from model import get_model
|
14 |
+
from torchvision import transforms
|
15 |
+
import sys
|
16 |
+
import json
|
17 |
+
import time
|
18 |
+
import torch.backends.cudnn as cudnn
|
19 |
+
from utils.tb_visualizer import Visualizer
|
20 |
+
from model import mask_from_eos, label2onehot
|
21 |
+
from utils.metrics import softIoU, compute_metrics, update_error_types
|
22 |
+
import random
|
23 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
24 |
+
map_loc = None if torch.cuda.is_available() else 'cpu'
|
25 |
+
|
26 |
+
|
27 |
+
def merge_models(args, model, ingr_vocab_size, instrs_vocab_size):
|
28 |
+
load_args = pickle.load(open(os.path.join(args.save_dir, args.project_name,
|
29 |
+
args.transfer_from, 'checkpoints/args.pkl'), 'rb'))
|
30 |
+
|
31 |
+
model_ingrs = get_model(load_args, ingr_vocab_size, instrs_vocab_size)
|
32 |
+
model_path = os.path.join(args.save_dir, args.project_name, args.transfer_from, 'checkpoints', 'modelbest.ckpt')
|
33 |
+
|
34 |
+
# Load the trained model parameters
|
35 |
+
model_ingrs.load_state_dict(torch.load(model_path, map_location=map_loc))
|
36 |
+
model.ingredient_decoder = model_ingrs.ingredient_decoder
|
37 |
+
args.transf_layers_ingrs = load_args.transf_layers_ingrs
|
38 |
+
args.n_att_ingrs = load_args.n_att_ingrs
|
39 |
+
|
40 |
+
return args, model
|
41 |
+
|
42 |
+
|
43 |
+
def save_model(model, optimizer, checkpoints_dir, suff=''):
|
44 |
+
if torch.cuda.device_count() > 1:
|
45 |
+
torch.save(model.module.state_dict(), os.path.join(
|
46 |
+
checkpoints_dir, 'model' + suff + '.ckpt'))
|
47 |
+
|
48 |
+
else:
|
49 |
+
torch.save(model.state_dict(), os.path.join(
|
50 |
+
checkpoints_dir, 'model' + suff + '.ckpt'))
|
51 |
+
|
52 |
+
torch.save(optimizer.state_dict(), os.path.join(
|
53 |
+
checkpoints_dir, 'optim' + suff + '.ckpt'))
|
54 |
+
|
55 |
+
|
56 |
+
def count_parameters(model):
|
57 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
58 |
+
|
59 |
+
|
60 |
+
def set_lr(optimizer, decay_factor):
|
61 |
+
for group in optimizer.param_groups:
|
62 |
+
group['lr'] = group['lr']*decay_factor
|
63 |
+
|
64 |
+
|
65 |
+
def make_dir(d):
|
66 |
+
if not os.path.exists(d):
|
67 |
+
os.makedirs(d)
|
68 |
+
|
69 |
+
|
70 |
+
def main(args):
|
71 |
+
|
72 |
+
# Create model directory & other aux folders for logging
|
73 |
+
where_to_save = os.path.join(args.save_dir, args.project_name, args.model_name)
|
74 |
+
checkpoints_dir = os.path.join(where_to_save, 'checkpoints')
|
75 |
+
logs_dir = os.path.join(where_to_save, 'logs')
|
76 |
+
tb_logs = os.path.join(args.save_dir, args.project_name, 'tb_logs', args.model_name)
|
77 |
+
make_dir(where_to_save)
|
78 |
+
make_dir(logs_dir)
|
79 |
+
make_dir(checkpoints_dir)
|
80 |
+
make_dir(tb_logs)
|
81 |
+
if args.tensorboard:
|
82 |
+
logger = Visualizer(tb_logs, name='visual_results')
|
83 |
+
|
84 |
+
# check if we want to resume from last checkpoint of current model
|
85 |
+
if args.resume:
|
86 |
+
args = pickle.load(open(os.path.join(checkpoints_dir, 'args.pkl'), 'rb'))
|
87 |
+
args.resume = True
|
88 |
+
|
89 |
+
# logs to disk
|
90 |
+
if not args.log_term:
|
91 |
+
print ("Training logs will be saved to:", os.path.join(logs_dir, 'train.log'))
|
92 |
+
sys.stdout = open(os.path.join(logs_dir, 'train.log'), 'w')
|
93 |
+
sys.stderr = open(os.path.join(logs_dir, 'train.err'), 'w')
|
94 |
+
|
95 |
+
print(args)
|
96 |
+
pickle.dump(args, open(os.path.join(checkpoints_dir, 'args.pkl'), 'wb'))
|
97 |
+
|
98 |
+
# patience init
|
99 |
+
curr_pat = 0
|
100 |
+
|
101 |
+
# Build data loader
|
102 |
+
data_loaders = {}
|
103 |
+
datasets = {}
|
104 |
+
|
105 |
+
data_dir = args.recipe1m_dir
|
106 |
+
for split in ['train', 'val']:
|
107 |
+
|
108 |
+
transforms_list = [transforms.Resize((args.image_size))]
|
109 |
+
|
110 |
+
if split == 'train':
|
111 |
+
# Image preprocessing, normalization for the pretrained resnet
|
112 |
+
transforms_list.append(transforms.RandomHorizontalFlip())
|
113 |
+
transforms_list.append(transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)))
|
114 |
+
transforms_list.append(transforms.RandomCrop(args.crop_size))
|
115 |
+
|
116 |
+
else:
|
117 |
+
transforms_list.append(transforms.CenterCrop(args.crop_size))
|
118 |
+
transforms_list.append(transforms.ToTensor())
|
119 |
+
transforms_list.append(transforms.Normalize((0.485, 0.456, 0.406),
|
120 |
+
(0.229, 0.224, 0.225)))
|
121 |
+
|
122 |
+
transform = transforms.Compose(transforms_list)
|
123 |
+
max_num_samples = max(args.max_eval, args.batch_size) if split == 'val' else -1
|
124 |
+
data_loaders[split], datasets[split] = get_loader(data_dir, args.aux_data_dir, split,
|
125 |
+
args.maxseqlen,
|
126 |
+
args.maxnuminstrs,
|
127 |
+
args.maxnumlabels,
|
128 |
+
args.maxnumims,
|
129 |
+
transform, args.batch_size,
|
130 |
+
shuffle=split == 'train', num_workers=args.num_workers,
|
131 |
+
drop_last=True,
|
132 |
+
max_num_samples=max_num_samples,
|
133 |
+
use_lmdb=args.use_lmdb,
|
134 |
+
suff=args.suff)
|
135 |
+
|
136 |
+
ingr_vocab_size = datasets[split].get_ingrs_vocab_size()
|
137 |
+
instrs_vocab_size = datasets[split].get_instrs_vocab_size()
|
138 |
+
|
139 |
+
# Build the model
|
140 |
+
model = get_model(args, ingr_vocab_size, instrs_vocab_size)
|
141 |
+
keep_cnn_gradients = False
|
142 |
+
|
143 |
+
decay_factor = 1.0
|
144 |
+
|
145 |
+
# add model parameters
|
146 |
+
if args.ingrs_only:
|
147 |
+
params = list(model.ingredient_decoder.parameters())
|
148 |
+
elif args.recipe_only:
|
149 |
+
params = list(model.recipe_decoder.parameters()) + list(model.ingredient_encoder.parameters())
|
150 |
+
else:
|
151 |
+
params = list(model.recipe_decoder.parameters()) + list(model.ingredient_decoder.parameters()) \
|
152 |
+
+ list(model.ingredient_encoder.parameters())
|
153 |
+
|
154 |
+
# only train the linear layer in the encoder if we are not transfering from another model
|
155 |
+
if args.transfer_from == '':
|
156 |
+
params += list(model.image_encoder.linear.parameters())
|
157 |
+
params_cnn = list(model.image_encoder.resnet.parameters())
|
158 |
+
|
159 |
+
print ("CNN params:", sum(p.numel() for p in params_cnn if p.requires_grad))
|
160 |
+
print ("decoder params:", sum(p.numel() for p in params if p.requires_grad))
|
161 |
+
# start optimizing cnn from the beginning
|
162 |
+
if params_cnn is not None and args.finetune_after == 0:
|
163 |
+
optimizer = torch.optim.Adam([{'params': params}, {'params': params_cnn,
|
164 |
+
'lr': args.learning_rate*args.scale_learning_rate_cnn}],
|
165 |
+
lr=args.learning_rate, weight_decay=args.weight_decay)
|
166 |
+
keep_cnn_gradients = True
|
167 |
+
print ("Fine tuning resnet")
|
168 |
+
else:
|
169 |
+
optimizer = torch.optim.Adam(params, lr=args.learning_rate)
|
170 |
+
|
171 |
+
if args.resume:
|
172 |
+
model_path = os.path.join(args.save_dir, args.project_name, args.model_name, 'checkpoints', 'model.ckpt')
|
173 |
+
optim_path = os.path.join(args.save_dir, args.project_name, args.model_name, 'checkpoints', 'optim.ckpt')
|
174 |
+
optimizer.load_state_dict(torch.load(optim_path, map_location=map_loc))
|
175 |
+
for state in optimizer.state.values():
|
176 |
+
for k, v in state.items():
|
177 |
+
if isinstance(v, torch.Tensor):
|
178 |
+
state[k] = v.to(device)
|
179 |
+
model.load_state_dict(torch.load(model_path, map_location=map_loc))
|
180 |
+
|
181 |
+
if args.transfer_from != '':
|
182 |
+
# loads CNN encoder from transfer_from model
|
183 |
+
model_path = os.path.join(args.save_dir, args.project_name, args.transfer_from, 'checkpoints', 'modelbest.ckpt')
|
184 |
+
pretrained_dict = torch.load(model_path, map_location=map_loc)
|
185 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if 'encoder' in k}
|
186 |
+
model.load_state_dict(pretrained_dict, strict=False)
|
187 |
+
args, model = merge_models(args, model, ingr_vocab_size, instrs_vocab_size)
|
188 |
+
|
189 |
+
if device != 'cpu' and torch.cuda.device_count() > 1:
|
190 |
+
model = nn.DataParallel(model)
|
191 |
+
|
192 |
+
model = model.to(device)
|
193 |
+
cudnn.benchmark = True
|
194 |
+
|
195 |
+
if not hasattr(args, 'current_epoch'):
|
196 |
+
args.current_epoch = 0
|
197 |
+
|
198 |
+
es_best = 10000 if args.es_metric == 'loss' else 0
|
199 |
+
# Train the model
|
200 |
+
start = args.current_epoch
|
201 |
+
for epoch in range(start, args.num_epochs):
|
202 |
+
|
203 |
+
# save current epoch for resuming
|
204 |
+
if args.tensorboard:
|
205 |
+
logger.reset()
|
206 |
+
|
207 |
+
args.current_epoch = epoch
|
208 |
+
# increase / decrase values for moving params
|
209 |
+
if args.decay_lr:
|
210 |
+
frac = epoch // args.lr_decay_every
|
211 |
+
decay_factor = args.lr_decay_rate ** frac
|
212 |
+
new_lr = args.learning_rate*decay_factor
|
213 |
+
print ('Epoch %d. lr: %.5f'%(epoch, new_lr))
|
214 |
+
set_lr(optimizer, decay_factor)
|
215 |
+
|
216 |
+
if args.finetune_after != -1 and args.finetune_after < epoch \
|
217 |
+
and not keep_cnn_gradients and params_cnn is not None:
|
218 |
+
|
219 |
+
print("Starting to fine tune CNN")
|
220 |
+
# start with learning rates as they were (if decayed during training)
|
221 |
+
optimizer = torch.optim.Adam([{'params': params},
|
222 |
+
{'params': params_cnn,
|
223 |
+
'lr': decay_factor*args.learning_rate*args.scale_learning_rate_cnn}],
|
224 |
+
lr=decay_factor*args.learning_rate)
|
225 |
+
keep_cnn_gradients = True
|
226 |
+
|
227 |
+
for split in ['train', 'val']:
|
228 |
+
|
229 |
+
if split == 'train':
|
230 |
+
model.train()
|
231 |
+
else:
|
232 |
+
model.eval()
|
233 |
+
total_step = len(data_loaders[split])
|
234 |
+
loader = iter(data_loaders[split])
|
235 |
+
|
236 |
+
total_loss_dict = {'recipe_loss': [], 'ingr_loss': [],
|
237 |
+
'eos_loss': [], 'loss': [],
|
238 |
+
'iou': [], 'perplexity': [], 'iou_sample': [],
|
239 |
+
'f1': [],
|
240 |
+
'card_penalty': []}
|
241 |
+
|
242 |
+
error_types = {'tp_i': 0, 'fp_i': 0, 'fn_i': 0, 'tn_i': 0,
|
243 |
+
'tp_all': 0, 'fp_all': 0, 'fn_all': 0}
|
244 |
+
|
245 |
+
torch.cuda.synchronize()
|
246 |
+
start = time.time()
|
247 |
+
|
248 |
+
for i in range(total_step):
|
249 |
+
|
250 |
+
img_inputs, captions, ingr_gt, img_ids, paths = loader.next()
|
251 |
+
|
252 |
+
ingr_gt = ingr_gt.to(device)
|
253 |
+
img_inputs = img_inputs.to(device)
|
254 |
+
captions = captions.to(device)
|
255 |
+
true_caps_batch = captions.clone()[:, 1:].contiguous()
|
256 |
+
loss_dict = {}
|
257 |
+
|
258 |
+
if split == 'val':
|
259 |
+
with torch.no_grad():
|
260 |
+
losses = model(img_inputs, captions, ingr_gt)
|
261 |
+
|
262 |
+
if not args.recipe_only:
|
263 |
+
outputs = model(img_inputs, captions, ingr_gt, sample=True)
|
264 |
+
|
265 |
+
ingr_ids_greedy = outputs['ingr_ids']
|
266 |
+
|
267 |
+
mask = mask_from_eos(ingr_ids_greedy, eos_value=0, mult_before=False)
|
268 |
+
ingr_ids_greedy[mask == 0] = ingr_vocab_size-1
|
269 |
+
pred_one_hot = label2onehot(ingr_ids_greedy, ingr_vocab_size-1)
|
270 |
+
target_one_hot = label2onehot(ingr_gt, ingr_vocab_size-1)
|
271 |
+
iou_sample = softIoU(pred_one_hot, target_one_hot)
|
272 |
+
iou_sample = iou_sample.sum() / (torch.nonzero(iou_sample.data).size(0) + 1e-6)
|
273 |
+
loss_dict['iou_sample'] = iou_sample.item()
|
274 |
+
|
275 |
+
update_error_types(error_types, pred_one_hot, target_one_hot)
|
276 |
+
|
277 |
+
del outputs, pred_one_hot, target_one_hot, iou_sample
|
278 |
+
|
279 |
+
else:
|
280 |
+
losses = model(img_inputs, captions, ingr_gt,
|
281 |
+
keep_cnn_gradients=keep_cnn_gradients)
|
282 |
+
|
283 |
+
if not args.ingrs_only:
|
284 |
+
recipe_loss = losses['recipe_loss']
|
285 |
+
|
286 |
+
recipe_loss = recipe_loss.view(true_caps_batch.size())
|
287 |
+
non_pad_mask = true_caps_batch.ne(instrs_vocab_size - 1).float()
|
288 |
+
|
289 |
+
recipe_loss = torch.sum(recipe_loss*non_pad_mask, dim=-1) / torch.sum(non_pad_mask, dim=-1)
|
290 |
+
perplexity = torch.exp(recipe_loss)
|
291 |
+
|
292 |
+
recipe_loss = recipe_loss.mean()
|
293 |
+
perplexity = perplexity.mean()
|
294 |
+
|
295 |
+
loss_dict['recipe_loss'] = recipe_loss.item()
|
                    loss_dict['perplexity'] = perplexity.item()
                else:
                    recipe_loss = 0

                if not args.recipe_only:

                    ingr_loss = losses['ingr_loss']
                    ingr_loss = ingr_loss.mean()
                    loss_dict['ingr_loss'] = ingr_loss.item()

                    eos_loss = losses['eos_loss']
                    eos_loss = eos_loss.mean()
                    loss_dict['eos_loss'] = eos_loss.item()

                    iou_seq = losses['iou']
                    iou_seq = iou_seq.mean()
                    loss_dict['iou'] = iou_seq.item()

                    card_penalty = losses['card_penalty'].mean()
                    loss_dict['card_penalty'] = card_penalty.item()
                else:
                    ingr_loss, eos_loss, card_penalty = 0, 0, 0

                loss = args.loss_weight[0] * recipe_loss + args.loss_weight[1] * ingr_loss \
                       + args.loss_weight[2]*eos_loss + args.loss_weight[3]*card_penalty

                loss_dict['loss'] = loss.item()

                for key in loss_dict.keys():
                    total_loss_dict[key].append(loss_dict[key])

                if split == 'train':
                    model.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Print log info
                if args.log_step != -1 and i % args.log_step == 0:
                    elapsed_time = time.time()-start
                    lossesstr = ""
                    for k in total_loss_dict.keys():
                        if len(total_loss_dict[k]) == 0:
                            continue
                        this_one = "%s: %.4f" % (k, np.mean(total_loss_dict[k][-args.log_step:]))
                        lossesstr += this_one + ', '
                    # this only displays nll loss on captions, the rest of losses will be in tensorboard logs
                    strtoprint = 'Split: %s, Epoch [%d/%d], Step [%d/%d], Losses: %sTime: %.4f' % (split, epoch,
                                                                                                   args.num_epochs, i,
                                                                                                   total_step,
                                                                                                   lossesstr,
                                                                                                   elapsed_time)
                    print(strtoprint)

                    if args.tensorboard:
                        # logger.histo_summary(model=model, step=total_step * epoch + i)
                        logger.scalar_summary(mode=split+'_iter', epoch=total_step*epoch+i,
                                              **{k: np.mean(v[-args.log_step:]) for k, v in total_loss_dict.items() if v})

                torch.cuda.synchronize()
                start = time.time()
                del loss, losses, captions, img_inputs

            if split == 'val' and not args.recipe_only:
                ret_metrics = {'accuracy': [], 'f1': [], 'jaccard': [], 'f1_ingredients': [], 'dice': []}
                compute_metrics(ret_metrics, error_types,
                                ['accuracy', 'f1', 'jaccard', 'f1_ingredients', 'dice'], eps=1e-10,
                                weights=None)

                total_loss_dict['f1'] = ret_metrics['f1']
            if args.tensorboard:
                # 1. Log scalar values (scalar summary)
                logger.scalar_summary(mode=split,
                                      epoch=epoch,
                                      **{k: np.mean(v) for k, v in total_loss_dict.items() if v})

        # Save the model's best checkpoint if performance was improved
        es_value = np.mean(total_loss_dict[args.es_metric])

        # save current model as well
        save_model(model, optimizer, checkpoints_dir, suff='')
        if (args.es_metric == 'loss' and es_value < es_best) or (args.es_metric == 'iou_sample' and es_value > es_best):
            es_best = es_value
            save_model(model, optimizer, checkpoints_dir, suff='best')
            pickle.dump(args, open(os.path.join(checkpoints_dir, 'args.pkl'), 'wb'))
            curr_pat = 0
            print('Saved checkpoint.')
        else:
            curr_pat += 1

        if curr_pat > args.patience:
            break

    if args.tensorboard:
        logger.close()


if __name__ == '__main__':
    args = get_parser()
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)
    random.seed(1234)
    np.random.seed(1234)
    main(args)
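For reference, the objective assembled above is a plain weighted sum of the four loss terms. A minimal sketch of that arithmetic with made-up numbers (the actual weights come from args.loss_weight, presumably set up in src/args.py; the values below are not the project defaults):

# Sketch only: hypothetical weights and per-batch loss values, mirroring the combination in train.py.
loss_weight = [1.0, 1000.0, 1.0, 1.0]
recipe_loss, ingr_loss, eos_loss, card_penalty = 2.30, 0.05, 0.10, 0.02
loss = loss_weight[0] * recipe_loss + loss_weight[1] * ingr_loss \
       + loss_weight[2] * eos_loss + loss_weight[3] * card_penalty
print(loss)  # 52.42 with the numbers above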
src/utils/ims2file.py
ADDED
@@ -0,0 +1,94 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import pickle
from tqdm import tqdm
import os
import numpy as np
from PIL import Image
import argparse
import lmdb
from torchvision import transforms


MAX_SIZE = 1e12


def load_and_resize(root, path, imscale):

    transf_list = []
    transf_list.append(transforms.Resize(imscale))
    transf_list.append(transforms.CenterCrop(imscale))
    transform = transforms.Compose(transf_list)

    img = Image.open(os.path.join(root, path[0], path[1], path[2], path[3], path)).convert('RGB')
    img = transform(img)

    return img


def main(args):

    parts = {}
    datasets = {}
    imname2pos = {'train': {}, 'val': {}, 'test': {}}
    for split in ['train', 'val', 'test']:
        datasets[split] = pickle.load(open(os.path.join(args.save_dir, args.suff + 'recipe1m_' + split + '.pkl'), 'rb'))

        parts[split] = lmdb.open(os.path.join(args.save_dir, 'lmdb_'+split), map_size=int(MAX_SIZE))
        with parts[split].begin() as txn:
            present_entries = [key for key, _ in txn.cursor()]
        j = 0
        for i, entry in tqdm(enumerate(datasets[split])):
            impaths = entry['images'][0:5]

            for n, p in enumerate(impaths):
                if n == args.maxnumims:
                    break
                if p.encode() not in present_entries:
                    im = load_and_resize(os.path.join(args.root, 'images', split), p, args.imscale)
                    im = np.array(im).astype(np.uint8)
                    with parts[split].begin(write=True) as txn:
                        txn.put(p.encode(), im)
                imname2pos[split][p] = j
                j += 1
    pickle.dump(imname2pos, open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'wb'))


def test(args):

    imname2pos = pickle.load(open(os.path.join(args.save_dir, 'imname2pos.pkl'), 'rb'))
    paths = imname2pos['val']

    for k, v in paths.items():
        path = k
        break
    image_file = lmdb.open(os.path.join(args.save_dir, 'lmdb_' + 'val'), max_readers=1, readonly=True,
                           lock=False, readahead=False, meminit=False)
    with image_file.begin(write=False) as txn:
        image = txn.get(path.encode())
        image = np.fromstring(image, dtype=np.uint8)
        image = np.reshape(image, (args.imscale, args.imscale, 3))
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    print(np.shape(image))


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='path/to/recipe1m',
                        help='path to the recipe1m dataset')
    parser.add_argument('--save_dir', type=str, default='../data',
                        help='path where the lmdbs will be saved')
    parser.add_argument('--imscale', type=int, default=256,
                        help='size of images (will be rescaled and center cropped)')
    parser.add_argument('--maxnumims', type=int, default=5,
                        help='maximum number of images to allow for each sample')
    parser.add_argument('--suff', type=str, default='',
                        help='id of the vocabulary to use')
    parser.add_argument('--test_only', dest='test_only', action='store_true')
    parser.set_defaults(test_only=False)
    args = parser.parse_args()

    if not args.test_only:
        main(args)
    test(args)
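As a usage sketch (not part of the commit), the LMDBs written by this script can also be read back outside of test(); save_dir, imscale and the split below are placeholders and must match the arguments used when building them:

# Sketch: fetch one preprocessed image from the val LMDB produced by ims2file.py.
import os, pickle
import lmdb
import numpy as np

save_dir, imscale = '../data', 256              # assumed to match --save_dir / --imscale
imname2pos = pickle.load(open(os.path.join(save_dir, 'imname2pos.pkl'), 'rb'))
key = next(iter(imname2pos['val']))             # any stored image filename
env = lmdb.open(os.path.join(save_dir, 'lmdb_val'), readonly=True, lock=False)
with env.begin(write=False) as txn:
    buf = txn.get(key.encode())
img = np.frombuffer(buf, dtype=np.uint8).reshape(imscale, imscale, 3)
print(img.shape)                                # (256, 256, 3)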
src/utils/metrics.py
ADDED
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import sys
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
map_loc = None if torch.cuda.is_available() else 'cpu'


class MaskedCrossEntropyCriterion(_WeightedLoss):

    def __init__(self, ignore_index=[-100], reduce=None):
        super(MaskedCrossEntropyCriterion, self).__init__()
        self.padding_idx = ignore_index
        self.reduce = reduce

    def forward(self, outputs, targets):
        lprobs = nn.functional.log_softmax(outputs, dim=-1)
        lprobs = lprobs.view(-1, lprobs.size(-1))

        for idx in self.padding_idx:
            # remove padding idx from targets to allow gathering without error (padded entries will be suppressed later)
            targets[targets == idx] = 0

        nll_loss = -lprobs.gather(dim=-1, index=targets.unsqueeze(1))
        if self.reduce:
            nll_loss = nll_loss.sum()

        return nll_loss.squeeze()


def softIoU(out, target, e=1e-6, sum_axis=1):

    num = (out*target).sum(sum_axis, True)
    den = (out+target-out*target).sum(sum_axis, True) + e
    iou = num / den

    return iou


def update_error_types(error_types, y_pred, y_true):

    error_types['tp_i'] += (y_pred * y_true).sum(0).cpu().data.numpy()
    error_types['fp_i'] += (y_pred * (1-y_true)).sum(0).cpu().data.numpy()
    error_types['fn_i'] += ((1-y_pred) * y_true).sum(0).cpu().data.numpy()
    error_types['tn_i'] += ((1-y_pred) * (1-y_true)).sum(0).cpu().data.numpy()

    error_types['tp_all'] += (y_pred * y_true).sum().item()
    error_types['fp_all'] += (y_pred * (1-y_true)).sum().item()
    error_types['fn_all'] += ((1-y_pred) * y_true).sum().item()


def compute_metrics(ret_metrics, error_types, metric_names, eps=1e-10, weights=None):

    if 'accuracy' in metric_names:
        ret_metrics['accuracy'].append(np.mean((error_types['tp_i'] + error_types['tn_i']) / (error_types['tp_i'] + error_types['fp_i'] + error_types['fn_i'] + error_types['tn_i'])))
    if 'jaccard' in metric_names:
        ret_metrics['jaccard'].append(error_types['tp_all'] / (error_types['tp_all'] + error_types['fp_all'] + error_types['fn_all'] + eps))
    if 'dice' in metric_names:
        ret_metrics['dice'].append(2*error_types['tp_all'] / (2*(error_types['tp_all'] + error_types['fp_all'] + error_types['fn_all']) + eps))
    if 'f1' in metric_names:
        pre = error_types['tp_i'] / (error_types['tp_i'] + error_types['fp_i'] + eps)
        rec = error_types['tp_i'] / (error_types['tp_i'] + error_types['fn_i'] + eps)
        f1_perclass = 2*(pre * rec) / (pre + rec + eps)
        if 'f1_ingredients' not in ret_metrics.keys():
            ret_metrics['f1_ingredients'] = [np.average(f1_perclass, weights=weights)]
        else:
            ret_metrics['f1_ingredients'].append(np.average(f1_perclass, weights=weights))

        pre = error_types['tp_all'] / (error_types['tp_all'] + error_types['fp_all'] + eps)
        rec = error_types['tp_all'] / (error_types['tp_all'] + error_types['fn_all'] + eps)
        f1 = 2*(pre * rec) / (pre + rec + eps)
        ret_metrics['f1'].append(f1)
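A minimal sketch of how these helpers fit together, mirroring their use in train.py: accumulate per-batch error counts with update_error_types, then derive the metrics once at the end of an epoch. The random tensors, the zero-initialised error_types dict and the import path below are assumptions for illustration only:

# Sketch: binary ingredient predictions vs. ground truth for a couple of batches.
import numpy as np
import torch
from utils.metrics import update_error_types, compute_metrics   # adjust the import path to your working directory

num_classes = 5
error_types = {'tp_i': 0, 'fp_i': 0, 'fn_i': 0, 'tn_i': 0,
               'tp_all': 0, 'fp_all': 0, 'fn_all': 0}

for _ in range(2):                                               # stand-ins for validation batches
    y_pred = torch.randint(0, 2, (4, num_classes)).float()
    y_true = torch.randint(0, 2, (4, num_classes)).float()
    update_error_types(error_types, y_pred, y_true)

ret_metrics = {'accuracy': [], 'f1': [], 'jaccard': [], 'f1_ingredients': [], 'dice': []}
compute_metrics(ret_metrics, error_types, ['accuracy', 'f1', 'jaccard', 'dice'], eps=1e-10)
print({k: v for k, v in ret_metrics.items() if v})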
src/utils/output_ing.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

replace_dict = {' .': '.',
                ' ,': ',',
                ' ;': ';',
                ' :': ':',
                '( ': '(',
                ' )': ')',
                " '": "'"}

def get_ingrs(ids, ingr_vocab_list):
    gen_ingrs = []
    for ingr_idx in ids:
        ingr_name = ingr_vocab_list[ingr_idx]
        if ingr_name == '<pad>':
            break
        gen_ingrs.append(ingr_name)
    return gen_ingrs


def prepare_output(gen_ingrs, ingr_vocab_list):

    if gen_ingrs is not None:
        gen_ingrs = get_ingrs(gen_ingrs, ingr_vocab_list)

    outs = {'ingrs': gen_ingrs}

    return outs
src/utils/output_utils.py
ADDED
@@ -0,0 +1,103 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

replace_dict = {' .': '.',
                ' ,': ',',
                ' ;': ';',
                ' :': ':',
                '( ': '(',
                ' )': ')',
                " '": "'"}


def get_recipe(ids, vocab):
    toks = []
    for id_ in ids:
        toks.append(vocab[id_])
    return toks


def get_ingrs(ids, ingr_vocab_list):
    gen_ingrs = []
    for ingr_idx in ids:
        ingr_name = ingr_vocab_list[ingr_idx]
        if ingr_name == '<pad>':
            break
        gen_ingrs.append(ingr_name)
    return gen_ingrs


def prettify(toks, replace_dict):
    toks = ' '.join(toks)
    toks = toks.split('<end>')[0]
    sentences = toks.split('<eoi>')

    pretty_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        sentence = sentence.capitalize()
        for k, v in replace_dict.items():
            sentence = sentence.replace(k, v)
        if sentence != '':
            pretty_sentences.append(sentence)
    return pretty_sentences


def colorized_list(ingrs, ingrs_gt, colorize=False):
    if colorize:
        colorized_list = []
        for word in ingrs:
            if word in ingrs_gt:
                word = '\033[1;30;42m ' + word + ' \x1b[0m'
            else:
                word = '\033[1;30;41m ' + word + ' \x1b[0m'
            colorized_list.append(word)
        return colorized_list
    else:
        return ingrs


def prepare_output(ids, gen_ingrs, ingr_vocab_list, vocab):

    toks = get_recipe(ids, vocab)
    is_valid = True
    reason = 'All ok.'
    try:
        cut = toks.index('<end>')
        toks_trunc = toks[0:cut]
    except:
        toks_trunc = toks
        is_valid = False
        reason = 'no eos found'

    # repetition score
    score = float(len(set(toks_trunc))) / float(len(toks_trunc))

    prev_word = ''
    found_repeat = False
    for word in toks_trunc:
        if prev_word == word and prev_word != '<eoi>':
            found_repeat = True
            break
        prev_word = word

    toks = prettify(toks, replace_dict)
    title = toks[0]
    toks = toks[1:]

    if gen_ingrs is not None:
        gen_ingrs = get_ingrs(gen_ingrs, ingr_vocab_list)

    if score <= 0.3:
        reason = 'Diversity score.'
        is_valid = False
    elif len(toks) != len(set(toks)):
        reason = 'Repeated instructions.'
        is_valid = False
    elif found_repeat:
        reason = 'Found word repeat.'
        is_valid = False

    valid = {'is_valid': is_valid, 'reason': reason, 'score': score}
    outs = {'title': title, 'recipe': toks, 'ingrs': gen_ingrs}

    return outs, valid
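A small usage sketch for prepare_output; the ids and vocabularies below are placeholders (in the demo they come from the model's sampled outputs and the dataset vocabularies), so this only illustrates the returned structure:

# Sketch: turning sampled token ids into a readable recipe and a validity verdict.
# recipe_ids / ingr_ids / vocab / ingr_vocab_list stand in for the real model outputs.
outs, valid = prepare_output(recipe_ids, ingr_ids, ingr_vocab_list, vocab)
if valid['is_valid']:
    print(outs['title'])
    print('\n'.join(outs['recipe']))
    print(', '.join(outs['ingrs']))
else:
    print('Rejected sample:', valid['reason'])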
src/utils/tb_visualizer.py
ADDED
@@ -0,0 +1,66 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import numpy as np
import os
import ntpath
import time
import glob
from scipy.misc import imresize
import torchvision.utils as vutils
from operator import itemgetter
from tensorboardX import SummaryWriter


class Visualizer():
    def __init__(self, checkpoints_dir, name):
        self.win_size = 256
        self.name = name
        self.saved = False
        self.checkpoints_dir = checkpoints_dir
        self.ncols = 4

        # remove existing
        for filename in glob.glob(self.checkpoints_dir+"/events*"):
            os.remove(filename)
        self.writer = SummaryWriter(checkpoints_dir)

    def reset(self):
        self.saved = False

    # images: (b, c, 0, 1) array of images
    def image_summary(self, mode, epoch, images):
        images = vutils.make_grid(images, normalize=True, scale_each=True)
        self.writer.add_image('{}/Image'.format(mode), images, epoch)

    # text: type: ingredients/recipe
    def text_summary(self, mode, epoch, type, text, vocabulary, gt=True, max_length=20):
        for i, el in enumerate(text):  # text_list
            if not gt:  # we are printing a sample
                idx = el.nonzero().squeeze() + 1
            else:
                idx = el  # we are printing the ground truth

            words_list = itemgetter(*idx)(vocabulary)

            if len(words_list) <= max_length:
                self.writer.add_text('{}/{}_{}_{}'.format(mode, type, i, 'gt' if gt else 'prediction'),
                                     ', '.join(filter(lambda x: x != '<pad>', words_list)), epoch)
            else:
                self.writer.add_text('{}/{}_{}_{}'.format(mode, type, i, 'gt' if gt else 'prediction'),
                                     'Number of sampled ingredients is too big: {}'.format(len(words_list)), epoch)

    # losses: dictionary of error labels and values
    def scalar_summary(self, mode, epoch, **args):
        for k, v in args.items():
            self.writer.add_scalar('{}/{}'.format(mode, k), v, epoch)

        self.writer.export_scalars_to_json("{}/tensorboard_all_scalars.json".format(self.checkpoints_dir))

    def histo_summary(self, model, step):
        """Log a histogram of the tensor of values."""

        for name, param in model.named_parameters():
            self.writer.add_histogram(name, param, step)

    def close(self):
        self.writer.close()
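A brief usage sketch for the Visualizer wrapper, mirroring how train.py passes losses to scalar_summary as keyword arguments; the directory, run name and values are placeholders:

# Sketch: logging a few scalars for one epoch with tensorboardX.
logger = Visualizer(checkpoints_dir='checkpoints/experiment', name='demo_run')
logger.scalar_summary(mode='train', epoch=0, loss=1.234, ingr_loss=0.56)
logger.scalar_summary(mode='val', epoch=0, loss=1.301)
logger.close()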