import os
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.utils import load_img, save_img, img_to_array
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D
from pymilvus import connections, Collection, utility
from requests import get
import streamlit as st
import zipfile
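
# Streamlit app: given the URL of a vegetable image, compute its embedding with a
# fine-tuned VGG19 model and show the 15 most similar images retrieved from a
# Milvus vector database (L2 distance).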


# unzip vegetable images
def unzip_images():
    with zipfile.ZipFile("Vegetable Images.zip", 'r') as zip_ref:
        zip_ref.extractall('.')
        print('unzipped images')

if not os.path.exists('Vegetable Images/'):
    unzip_images()


class ImageVectorizer:
    '''
    Get a vector representation of an image using a VGG19 model fine-tuned on vegetable images for classification.
    '''
    
    def __init__(self):
        self.__model = self.get_model()
    
    @staticmethod
    @st.cache_resource
    def get_model():
        # load the saved VGG19 model fine-tuned on vegetable images for classification
        model = load_model('vegetable_classification_model_vgg.h5')
        # drop the classification head: take the output of the last convolutional
        # block (block5_pool) and pool it into one 512-dimensional embedding per image
        top = model.get_layer('block5_pool').output
        top = GlobalAveragePooling2D()(top)
        model = Model(inputs=model.input, outputs=top)
        print('loaded model')
        return model
    
    def vectorize(self, img_path: str):
        model = self.__model
        test_image = load_img(img_path, color_mode="rgb", target_size=(224, 224))
        test_image = img_to_array(test_image)
        test_image = preprocess_input(test_image) # VGG-style preprocessing (RGB -> BGR, ImageNet mean subtraction)
        test_image = np.array([test_image])       # add a batch dimension
        return model(test_image).numpy()[0]
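
    # Example usage (illustrative): ImageVectorizer().vectorize('uploads/input.jpg')
    # returns a 512-dimensional numpy array (the 512 channels of block5_pool after
    # global average pooling), which is the vector searched against the Milvus collection.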


@st.cache_resource
def get_milvus_collection():
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    print("Connected to DB")
    collection_name = os.environ.get("COLLECTION_NAME")
    collection = Collection(name=collection_name)
    collection.load()
    return collection
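
# The Milvus collection is assumed to have been created and populated beforehand;
# this app only queries it. A sketch of the assumed schema (using pymilvus's
# FieldSchema/CollectionSchema), for reference only: the primary key field is
# illustrative, while 'image_vector' and 'image_path' are the names used in
# find_similar_images below, and 512 is the size of the pooled VGG19 embedding.
#
#   fields = [
#       FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
#       FieldSchema(name="image_vector", dtype=DataType.FLOAT_VECTOR, dim=512),
#       FieldSchema(name="image_path", dtype=DataType.VARCHAR, max_length=512),
#   ]
#   schema = CollectionSchema(fields)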


def plot_images(input_image_path: str, similar_img_paths: list):
    # plotting similar images
    rows = 5 # subplot grid rows
    cols = 3 # subplot grid columns (5 x 3 = 15 cells, one per retrieved image)
    fig, ax = plt.subplots(rows, cols, figsize=(12, 20))
    r = 0
    c = 0
    for i in range(rows*cols):
        sim_image = load_img(similar_img_paths[i], color_mode="rgb", target_size=(224, 224))
        ax[r,c].axis("off")
        ax[r,c].imshow(sim_image)
        c += 1
        if c == cols:
            c = 0
            r += 1
    plt.subplots_adjust(wspace=0.01, hspace=0.01)

    # display the input image (placeholder is the module-level st.empty() created in the main block below)
    rcParams.update({'figure.autolayout': True})
    input_image = load_img(input_image_path, color_mode="rgb", target_size=(224, 224))
    with placeholder.container():
        st.markdown('<p style="font-size: 20px; font-weight: bold">Input image</p>', unsafe_allow_html=True)
        st.image(input_image)
        
        st.write('  \n')
        
        # display similar images
        st.markdown('<p style="font-size: 20px; font-weight: bold">Similar images</p>', unsafe_allow_html=True)
        st.pyplot(fig)
        

def find_similar_images(img_path: str, top_n: int=15):
    search_params = {"metric_type": "L2"}
    search_vec = vectorizer.vectorize(img_path)
    result = collection.search([search_vec],
                                anns_field='image_vector', # vector field to search, as named in the collection schema
                                param=search_params,
                                limit=top_n,
                                guarantee_timestamp=1,
                                output_fields=['image_path']) # fields to return with each hit
    
    output_dict = {"input_image_path": img_path,
                   "similar_image_paths": [hit.entity.get('image_path') for hits in result for hit in hits]} # result holds one list of hits per query vector
    plot_images(output_dict['input_image_path'], output_dict['similar_image_paths'])


def delete_file(path_: str):
    if os.path.exists(path_):
        os.remove(path_)

@st.cache_resource
def get_upload_path():
    upload_file_path = os.path.join('.', 'uploads')
    if not os.path.exists(upload_file_path):
        os.makedirs(upload_file_path)
    upload_filename = "input.jpg"
    upload_file_path = os.path.join(upload_file_path, upload_filename)
    return upload_file_path

def process_input_image(img_url):
    upload_file_path = get_upload_path()
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
    r = get(img_url, headers=headers, timeout=30)
    r.raise_for_status()  # surface download failures instead of silently saving an error page
    with open(upload_file_path, "wb") as file:
        file.write(r.content)
    return upload_file_path
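
# Note: the browser-like User-Agent header reduces the chance of image hosts
# rejecting the request; the downloaded bytes are written to ./uploads/input.jpg
# and removed again after the search (see delete_file in the main flow below).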


vectorizer = ImageVectorizer()
collection = get_milvus_collection()

try:
    st.markdown("<h3>Find Similar Vegetable Images</h3>", unsafe_allow_html=True)
    desc = '''<p style="font-size: 15px;">Allowed Vegetables: Broad Bean, Bitter Gourd, Bottle Gourd, 
    Green Round Brinjal, Broccoli, Cabbage, Capsicum, Carrot, Cauliflower, Cucumber, 
    Raw Papaya, Potato, Green Pumpkin, Radish, Tomato.
    </p> 
    <p style="font-size: 13px;">Image embeddings are extracted from a fine-tuned VGG model. The model is fine-tuned on <a href="https://www.kaggle.com/datasets/misrakahmed/vegetable-image-dataset" target="_blank">images</a> captured with a mobile phone camera. 
    Embeddings of 20,000 vegetable images are stored in a Milvus vector database. The embedding of the input image is computed and the 15 most similar images (by L2 distance) are displayed.</p>
    '''
    st.markdown(desc, unsafe_allow_html=True)
    img_url = st.text_input("Paste the image URL of a vegetable and hit Enter:", "")
    placeholder = st.empty()
    if img_url:
        placeholder.empty()
        img_path = process_input_image(img_url)
        find_similar_images(img_path, 15)
        delete_file(img_path)
except Exception as e:
    st.error(f'An unexpected error occurred:  \n{e}')