davidaf3's picture
Fixed typo
0832cb7
raw
history blame
7.22 kB
from typing import Dict, List, Any
from PIL import Image
from tfing import TFIng
from tfport import TFPort, get_look_ahead_mask, get_padding_mask
import os
import json
import tensorflow as tf
import numpy as np
class PreTrainedPipeline():
    """Predict ingredients, portions and nutrition values from a food image.

    Two project models are combined: ``TFIng`` is used to decode a sequence
    of ingredient tokens from the image, and ``TFPort`` is used to decode one
    portion value per ingredient.

    Token conventions used throughout this class (all derived from
    ``vocab_size = len(ingredients) + 3``):

    * ``0`` is treated as padding,
    * ``vocab_size - 2`` is written as the first decoder input (start token),
    * ``vocab_size - 1`` terminates a sequence (end token).
    """

    def __init__(self, path=""):
        """Load the ingredient metadata and the weights of both models.

        Args:
            path: Directory containing ``ingredients_metadata.json``,
                ``tfing.h5`` and ``tfport.h5``.
        """
        crop_size = (224, 224)
        embed_dim = 256
        num_layers = 3
        seq_length = 20
        hidden_dim = 1024
        num_heads = 8

        self.nutr_names = ('energy', 'fat', 'protein', 'carbs')
        with open(os.path.join(path, "ingredients_metadata.json"), encoding='UTF-8') as f:
            self.ingredients = json.load(f)
        # Reverse lookup: ingredient name -> integer token id.
        self.ing_names = {
            ing['name']: int(ing_id) for ing_id, ing in self.ingredients.items()
        }
        self.vocab_size = len(self.ingredients) + 3
        self.seq_length = seq_length

        # __call__ reads both attributes below; they were previously never
        # assigned, which made every call fail with AttributeError.
        self.crop_size = crop_size
        # NOTE(review): img_size is the length the shorter image side is
        # resized to before the centre crop. max(crop_size) is the smallest
        # value for which the crop always fits - confirm it matches the
        # preprocessing used when the models were trained.
        self.img_size = max(crop_size)

        self.tfing = TFIng(
            crop_size,
            embed_dim,
            num_layers,
            seq_length,
            hidden_dim,
            num_heads,
            self.vocab_size
        )
        self.tfing.compile()
        # One forward pass on zero inputs before loading the weights
        # (presumably required so the model variables exist - confirm).
        self.tfing((tf.zeros((1, 224, 224, 3)), tf.zeros((1, seq_length))))
        self.tfing.load_weights(os.path.join(path, 'tfing.h5'))

        self.tfport = TFPort(
            crop_size,
            embed_dim,
            num_layers,
            num_layers,
            seq_length,
            seq_length,
            hidden_dim,
            num_heads,
            self.vocab_size
        )
        self.tfport.compile()
        self.tfport((
            tf.zeros((1, 224, 224, 3)),
            tf.zeros((1, seq_length)),
            tf.zeros((1, seq_length)),
        ))
        self.tfport.load_weights(os.path.join(path, 'tfport.h5'))

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Run the full pipeline on one image.

        Args:
            inputs: The image to analyse.

        Returns:
            One ``{"label": ingredient_name, "score": portion}`` dict per
            predicted ingredient, where the portions are scaled to sum to 100.
        """
        # Both models are built with 3-channel inputs, so other PIL modes
        # (grayscale, RGBA, palette, ...) are converted first.
        if getattr(inputs, "mode", "RGB") != "RGB":
            inputs = inputs.convert("RGB")
        image = tf.keras.preprocessing.image.img_to_array(inputs)
        height, width = image.shape[0], image.shape[1]

        # Resize so that the shorter side equals img_size, keeping the
        # aspect ratio, then centre-crop to crop_size.
        if width > height:
            target_size = (
                self.img_size,
                int(float(self.img_size * width) / float(height)),
            )
        else:
            target_size = (
                int(float(self.img_size * height) / float(width)),
                self.img_size,
            )
        image = tf.image.resize(image, target_size)
        image = tf.keras.applications.inception_v3.preprocess_input(image)
        image = tf.keras.layers.CenterCrop(*self.crop_size)(image)

        prediction = self.predict(image)
        return [
            {"label": label, "score": score}
            for label, score in zip(prediction['ingredients'], prediction['portions'])
        ]

    def encode_image(self, image):
        """Encode a batch of images with the ``TFIng`` encoder.

        The 4-D output of ``tfing.conv`` is reshaped to 3-D by merging its
        two middle axes, keeping the first and last axes unchanged.
        """
        encoder_out = self.tfing.encoder(image)
        encoder_out = self.tfing.conv(encoder_out)
        out_shape = tf.shape(encoder_out)
        return tf.reshape(encoder_out, (out_shape[0], -1, out_shape[3]))

    def encode_ingredients(self, ingredients, padding_mask):
        """Encode ingredient tokens with the ``TFPort`` ingredient encoder."""
        return self.tfport.ingredient_encoder(ingredients, padding_mask)

    def decode_ingredients(self, encoded_img, decoder_in):
        """Return per-position ingredient scores for *decoder_in*.

        The decoder output is projected by ``tfing.linear`` and then offset
        by ``tfing.get_replacement_mask(decoder_in)`` (presumably to stop
        tokens already in the sequence from being predicted again - confirm
        against ``TFIng``).
        """
        decoder_outputs = self.tfing.decoder(decoder_in, encoded_img)
        output = self.tfing.linear(decoder_outputs)
        return output + self.tfing.get_replacement_mask(decoder_in)

    def decode_portions(self, encoded_img, encoded_ingr, decoder_in, padding_mask):
        """Run the ``TFPort`` decoder stack and return the squeezed output.

        The encoded image and encoded ingredients are concatenated along
        axis 1, and the ingredient padding mask is extended with ones for the
        image positions so that it lines up with the concatenated sequence.
        """
        encoder_outputs = tf.concat([encoded_img, encoded_ingr], axis=1)
        img_shape = tf.shape(encoded_img)
        img_mask = tf.ones((img_shape[0], 1, img_shape[1]), dtype=tf.int32)
        padding_mask = tf.concat([img_mask, padding_mask], axis=2)
        look_ahead_mask = get_look_ahead_mask(decoder_in)

        x = self.tfport.portion_embedding(decoder_in)
        for decoder_layer in self.tfport.decoder_layers:
            x = decoder_layer(
                x, encoder_outputs, look_ahead_mask, padding_mask=padding_mask
            )
        x = self.tfport.linear(x)
        return tf.squeeze(x)

    def predict_ingredients(self, encoded_img, known_ing=None):
        """Greedily decode ingredient tokens for one encoded image.

        Args:
            encoded_img: Output of :meth:`encode_image` for a single image.
            known_ing: Optional list of token ids that are already known;
                decoding continues after them. At most ``seq_length`` of
                them are used.

        Returns:
            A 1-D integer array of length ``seq_length`` (the start token is
            not included). Positions after the end token stay ``0``.
        """
        start_token = self.vocab_size - 2
        end_token = self.vocab_size - 1

        predicted = np.zeros((1, self.seq_length + 1), dtype=int)
        predicted[0, 0] = start_token
        start_index = 0
        if known_ing:
            # More known ingredients than slots would not fit the buffer.
            known_ing = known_ing[:self.seq_length]
            predicted[0, 1:len(known_ing) + 1] = known_ing
            start_index = len(known_ing)

        for i in range(start_index, self.seq_length):
            decoded = self.decode_ingredients(encoded_img, predicted[:, :-1])
            next_token = int(np.argmax(decoded[0, i]))
            predicted[0, i + 1] = next_token
            if next_token == end_token:
                return predicted[0, 1:]
            if i == self.seq_length - 1:
                # Out of room: the last slot always holds the end token.
                predicted[0, i + 1] = end_token
        return predicted[0, 1:]

    def predict_portions(self, encoded_image, ingredients):
        """Decode one portion value per ingredient, one position at a time.

        Args:
            encoded_image: Output of :meth:`encode_image` for a single image.
            ingredients: Integer array of shape ``(1, seq_length)`` holding
                the ingredient tokens.

        Returns:
            A 1-D float array of length ``seq_length``; positions from the
            end token onwards stay ``0``.
        """
        end_token = self.vocab_size - 1

        predicted = np.zeros((1, self.seq_length + 1), dtype=float)
        predicted[0, 0] = -1  # value fed to the decoder at the first position
        padding_mask = get_padding_mask(ingredients)
        encoded_ingr = self.encode_ingredients(ingredients, padding_mask)

        for i in range(self.seq_length):
            if ingredients[0, i] == end_token:
                return predicted[0, 1:]
            decoded = self.decode_portions(
                encoded_image,
                encoded_ingr,
                predicted[:, :-1],
                padding_mask
            )
            predicted[0, i + 1] = float(decoded[i])
        return predicted[0, 1:]

    def process_ingredients(self, ingredients):
        """Convert newline-separated ingredient names to token ids.

        Names are stripped of surrounding whitespace; names that are not in
        ``self.ing_names`` are ignored. A line consisting of ``.`` stops the
        parsing.

        Returns:
            ``(token_ids, stopped)`` where ``stopped`` is ``True`` if a ``.``
            line was found and ``False`` otherwise.
        """
        processed = []
        for ingredient in ingredients.split('\n'):
            stripped = ingredient.strip()
            if stripped == '.':
                return processed, True
            if stripped in self.ing_names:
                processed.append(self.ing_names[stripped])
        return processed, False

    def _summarize(self, ingredients, readable_ingredients, portions):
        """Build the result dict returned by :meth:`predict`.

        Portions are rescaled so that they sum to 100. If there are no
        ingredients, or all portions are zero, no rescaling is possible and
        the scale factor is 0 instead of dividing by zero.
        """
        count = len(readable_ingredients)
        portions_slice = portions[:count]
        total = sum(portions_slice)
        scale = 100 / total if total else 0
        return {
            'ingredients': readable_ingredients,
            'portions': [portion * scale for portion in portions_slice],
            # Uses the unscaled portions and the first `count` tokens, i.e.
            # it relies on the readable tokens forming a prefix of
            # `ingredients`.
            'nutrition': {
                name: sum(
                    self.ingredients[str(ingredients[i])][name] * portions[i] / 100
                    for i in range(count)
                ) for name in self.nutr_names
            }
        }

    def predict(self, image, known_ing=None):
        """Predict ingredients, portions and nutrition for one image.

        Args:
            image: A single preprocessed image without batch axis.
            known_ing: Optional newline-separated ingredient names (see
                :meth:`process_ingredients`). If the text contains a ``.``
                line, the names before it are used as the full ingredient
                list and no ingredients are predicted.

        Returns:
            A dict with ``'ingredients'`` (names), ``'portions'`` (scaled to
            sum to 100) and ``'nutrition'`` (one total per entry of
            ``self.nutr_names``).
        """
        end_token = self.vocab_size - 1
        encoded_image = self.encode_image(image[tf.newaxis, :])

        if known_ing:
            known_ing, skip_ing = self.process_ingredients(known_ing)
        else:
            known_ing, skip_ing = None, False

        if skip_ing:
            # Leave room for the end token.
            ingredients = known_ing[:self.seq_length - 1]
            ingredients.append(end_token)
        else:
            ingredients = self.predict_ingredients(encoded_image, known_ing=known_ing)
        ingredients = np.pad(ingredients, (0, self.seq_length - len(ingredients)))

        readable_ingredients = [
            self.ingredients[str(token)]['name'] for token in ingredients
            if token != 0 and token != end_token
        ]

        # With a single ingredient there is nothing to split: it gets 100.
        if len(readable_ingredients) > 1:
            portions = self.predict_portions(encoded_image, ingredients[tf.newaxis, :])
        else:
            portions = [100]

        return self._summarize(ingredients, readable_ingredients, portions)