ksvmuralidhar's picture
Update app.py
7301ba2 verified
raw
history blame
5.95 kB
import os
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.utils import load_img, save_img, img_to_array
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D
from pymilvus import connections, Collection, utility
from requests import get
from shutil import rmtree
import streamlit as st
import zipfile
import logging
# Dataset bootstrap: the image files ship inside a zip archive next to this script.
def unzip_images():
    """Extract ``Vegetable Images.zip`` into the current working directory."""
    archive_path, target_dir = "Vegetable Images.zip", '.'
    with zipfile.ZipFile(archive_path, mode='r') as archive:
        archive.extractall(target_dir)
    logging.debug('unzipped images')
# Extract the bundled archive only on the first start; later starts reuse the folder.
_images_dir = 'Vegetable Images/'
if not os.path.exists(_images_dir):
    unzip_images()
class ImageVectorizer:
    """Turn an image file into a feature vector.

    Features come from a VGG19-based classifier (fine-tuned on vegetable images)
    cut at its ``block5_pool`` layer and followed by global average pooling.
    """

    def __init__(self):
        # Name-mangled attribute: only reachable through this class's methods.
        self.__model = self.get_model()

    @staticmethod
    @st.cache_resource
    def get_model():
        """Load the saved classifier and return a feature-extractor ``Model`` built from it.

        Wrapped in ``st.cache_resource`` so the weights are loaded once per process
        instead of on every Streamlit rerun.
        """
        classifier = load_model('vegetable_classification_model_vgg.h5')
        pooled_features = GlobalAveragePooling2D()(classifier.get_layer('block5_pool').output)
        return Model(inputs=classifier.input, outputs=pooled_features)

    def vectorize(self, img_path: str):
        """Return the pooled feature vector of the image stored at ``img_path``.

        The image is loaded as RGB, resized to 224x224, run through the VGG19
        ``preprocess_input`` and fed to the model as a batch of one; the single
        row of the model output is returned as a NumPy array.
        """
        pixels = img_to_array(load_img(img_path, color_mode="rgb", target_size=(224, 224)))
        batch = np.expand_dims(preprocess_input(pixels), axis=0)
        return self.__model(batch).numpy()[0]
@st.cache_resource
def get_milvus_collection():
    """Connect to Milvus and return the image-embedding collection, loaded for search.

    Settings come from the environment: ``URI`` and ``TOKEN`` are passed to
    ``connections.connect`` under the "default" alias, and ``COLLECTION_NAME``
    names the collection to open. Wrapped in ``st.cache_resource`` so the
    connection and collection are created once per process, not on every rerun.

    Returns:
        The pymilvus ``Collection`` after ``load()`` has been called on it.

    Raises:
        RuntimeError: if ``COLLECTION_NAME`` is unset or empty.
    """
    collection_name = os.environ.get("COLLECTION_NAME")
    # Fail fast with an actionable message; otherwise Collection(name=None) only
    # fails later inside pymilvus with a much less obvious error.
    if not collection_name:
        raise RuntimeError("Environment variable COLLECTION_NAME is not set")
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    print("Connected to DB")
    collection = Collection(name=collection_name)
    collection.load()
    return collection
def plot_images(input_image_path: str, similar_img_paths: list, *, cols: int = 3):
    """Show the query image and a grid of its similar images in ``placeholder``.

    The query image is shown with ``st.image``; the similar images are drawn on a
    matplotlib grid and rendered with ``st.pyplot``. Everything is written into
    the module-level ``placeholder`` created by the page script (``st.empty()``).

    Args:
        input_image_path: Path of the query image.
        similar_img_paths: Paths of the retrieved images, drawn row by row in the
            given order. Any length is accepted; the number of rows is derived
            from it (15 paths give the original 5x3 grid at 12x20 inches).
        cols: Number of grid columns (keyword-only, default 3).
    """
    n_images = len(similar_img_paths)
    # Derive the row count from the results instead of hard-coding 5 rows: the old
    # fixed 15-iteration loop raised IndexError whenever fewer than 15 paths arrived
    # (e.g. find_similar_images(..., top_n=10) or a small collection).
    rows = max(1, (n_images + cols - 1) // cols)
    # squeeze=False keeps ``ax`` 2-D even for a single row; 4 inches per row
    # reproduces the original figsize=(12, 20) for 5 rows.
    fig, ax = plt.subplots(rows, cols, figsize=(12, 4 * rows), squeeze=False)
    try:
        for axis in ax.flat:
            axis.axis("off")  # also blanks unused trailing cells of the last row
        for axis, sim_path in zip(ax.flat, similar_img_paths):
            sim_image = load_img(sim_path, color_mode="rgb", target_size=(224, 224))
            axis.imshow(sim_image)
        fig.subplots_adjust(wspace=0.01, hspace=0.01)
        # NOTE(review): this mutates global matplotlib state and only affects figures
        # created after this call (i.e. the grid on later searches, not this one).
        rcParams.update({'figure.autolayout': True})

        # display input image
        input_image = load_img(input_image_path, color_mode="rgb", target_size=(224, 224))
        with placeholder.container():
            st.markdown('<p style="font-size: 20px; font-weight: bold">Input image</p>', unsafe_allow_html=True)
            st.image(input_image)
            st.write(' \n')
            # display similar images
            st.markdown('<p style="font-size: 20px; font-weight: bold">Similar images</p>', unsafe_allow_html=True)
            st.pyplot(fig)
    finally:
        # Figures created through pyplot stay registered until explicitly closed;
        # without this every search would leave one open figure in the server process.
        plt.close(fig)
def find_similar_images(img_path: str, top_n: int=15):
    """Find the ``top_n`` stored images nearest to ``img_path`` and display them.

    The query embedding comes from the module-level ``vectorizer``; the search runs
    against the module-level Milvus ``collection`` using L2 distance on the
    ``image_vector`` field, and the ``image_path`` of each hit is passed to
    ``plot_images`` together with the query path.
    """
    query_vector = vectorizer.vectorize(img_path)
    hits_per_query = collection.search(
        [query_vector],
        anns_field='image_vector',  # vector field searched (approximate nearest neighbours)
        param={"metric_type": "L2"},
        limit=top_n,
        # NOTE(review): value kept from the original; presumably chosen to avoid
        # waiting on strong consistency — confirm against the pymilvus docs.
        guarantee_timestamp=1,
        output_fields=['image_path'],  # scalar field returned alongside each hit
    )
    # One query vector was sent, but flatten over all result sets regardless.
    neighbour_paths = [hit.entity.get('image_path') for hits in hits_per_query for hit in hits]
    plot_images(img_path, neighbour_paths)
def delete_file(path_: str):
    """Delete the file at ``path_`` if it exists; do nothing if it is already gone.

    Uses try/except instead of an ``exists`` check: a file removed between the
    check and ``os.remove`` (e.g. by another session) no longer raises
    ``FileNotFoundError``. Other errors (such as ``path_`` being a directory)
    still propagate.
    """
    try:
        os.remove(path_)
    except FileNotFoundError:
        pass
@st.cache_resource
def get_upload_path():
    """Return the file path the downloaded query image is written to.

    The path is ``./uploads/input.jpg``; the ``uploads`` directory is created if
    needed. Wrapped in ``st.cache_resource``, so this runs once per process and
    every session receives the same path.
    """
    upload_dir = os.path.join('.', 'uploads')
    # exist_ok=True replaces the racy exists()-then-makedirs() check.
    os.makedirs(upload_dir, exist_ok=True)
    # NOTE(review): one fixed filename is shared by all concurrent sessions, so
    # simultaneous searches can overwrite or delete each other's input image.
    return os.path.join(upload_dir, "input.jpg")
def process_input_image(img_url, *, timeout: float = 30):
    """Download the image at ``img_url`` to the upload path and return that path.

    Args:
        img_url: URL entered by the user.
        timeout: Seconds to wait on the server, passed to ``requests.get``. Without
            it an unresponsive host would block the Streamlit script indefinitely.

    Returns:
        The local file path (from ``get_upload_path``) the response body was written to.

    Raises:
        requests.HTTPError: on a 4xx/5xx response, rather than saving the error page
            as the "image" and failing later in ``load_img``.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    upload_file_path = get_upload_path()
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
    # NOTE(review): this fetches an arbitrary user-supplied URL from the server
    # (SSRF exposure) and reads the whole body into memory with no size limit.
    r = get(img_url, headers=headers, timeout=timeout)
    r.raise_for_status()
    with open(upload_file_path, "wb") as file:
        file.write(r.content)
    return upload_file_path
# Heavy resources; both factories are wrapped in st.cache_resource.
vectorizer = ImageVectorizer()
collection = get_milvus_collection()
try:
    st.markdown("<h3>Find Similar Vegetable Images</h3>", unsafe_allow_html=True)
    desc = '''<p style="font-size: 15px;">Allowed Vegetables: Broad Bean, Bitter Gourd, Bottle Gourd,
Green Round Brinjal, Broccoli, Cabbage, Capsicum, Carrot, Cauliflower, Cucumber,
Raw Papaya, Potato, Green Pumpkin, Radish, Tomato.
</p>
<p style="font-size: 13px;">Image embeddings are extracted from a fine-tuned VGG model. The model is fine-tuned on <a href="https://www.kaggle.com/datasets/misrakahmed/vegetable-image-dataset" target="_blank">images</a> clicked using a mobile phone camera.
Embeddings of 20,000 vegetable images are stored in Milvus vector database. Embeddings of the input image are computed and 15 most similar images (based on L2 distance) are displayed.</p>
'''
    st.markdown(desc, unsafe_allow_html=True)
    # Pasted URLs often carry stray whitespace; strip it before fetching.
    img_url = st.text_input("Paste the image URL of a vegetable and hit Enter:", "").strip()
    # Module-level container that plot_images() renders into.
    placeholder = st.empty()
    if img_url:
        placeholder.empty()
        img_path = process_input_image(img_url)
        try:
            find_similar_images(img_path, 15)
        finally:
            # Remove the downloaded file even when vectorizing/searching/plotting fails.
            delete_file(img_path)
except Exception as e:
    # Top-level UI boundary: keep the traceback in the server log, show a short message.
    logging.exception('Similar-image search failed')
    st.error(f'An unexpected error occurred: \n{e}')