laudavid committed
Commit 18edec3 • 1 Parent(s): c847bd4

modified object detection example

data/dior_show/{dior1.jpg → images/dior1.jpg} RENAMED
File without changes
data/dior_show/{dior2.jpg → images/dior2.jpg} RENAMED
File without changes
data/dior_show/{dior3.jpg → images/dior3.jpg} RENAMED
File without changes
data/dior_show/{dior4.jpg → images/dior4.jpg} RENAMED
File without changes
data/dior_show/results/dior1_results.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d8e9ca5b6ebd77da5cf893e062316121e3145bca6afc993e350aec9780b64fd
+size 5936769
data/dior_show/results/dior2_results.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7f562867db3dc0cc8fe8d49fd87ce2aadae96715e0dabe0a9dd24cc913a1c31
+size 5936769
data/dior_show/results/dior3_results.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:155efde06760dc3b86684e4cfd96209cb528df9a7b7d2696df49863f23d1e4e6
+size 5936769
data/dior_show/results/dior4_results.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39d60dcf806580a5092f3f4104fc92101e18863524f34fe2bf909ba66e884973
+size 5936769
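
The four result pickles above are committed as Git LFS pointer files: the repository stores only a spec version, a SHA-256 `oid`, and the blob size, while the real payload lives in LFS storage. As a quick integrity check after `git lfs pull`, the digest of a pulled blob should match its pointer's `oid` (a minimal sketch using the dior1 pointer above; not part of the commit):

```python
import hashlib

def sha256_of(path):
    # Hash in 1 MiB chunks so the ~6 MB pickles are not read into memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

# oid and path copied verbatim from the dior1 pointer above.
assert sha256_of("data/dior_show/results/dior1_results.pkl") == (
    "6d8e9ca5b6ebd77da5cf893e062316121e3145bca6afc993e350aec9780b64fd"
)
```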
pages/object_detection.py CHANGED
@@ -4,14 +4,14 @@ import streamlit as st
 import matplotlib.pyplot as plt
 import pandas as pd
 import numpy as np
-#import altair as alt
 import plotly.express as px
-
+import pickle
+import random
 
 from PIL import Image
 from transformers import YolosFeatureExtractor, YolosForObjectDetection
 from torchvision.transforms import ToTensor, ToPILImage
-#from utils import load_model_huggingface
+from annotated_text import annotated_text
 
 
 st.set_page_config(layout="wide")
@@ -80,8 +80,7 @@ def plot_results(pil_img, prob, boxes):
         ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                    fill=False, color=c, linewidth=3))
         ax.text(xmin, ymin, f"{idx_to_text(cl)}", fontsize=10,
-                bbox=dict(facecolor=c, alpha=0.8))
-
+                bbox=dict(facecolor=c, alpha=0.8))
     plt.axis('off')
 
     plt.savefig("results_od.png",
@@ -100,18 +99,6 @@ def return_probas(outputs, threshold):
     return probas, keep
 
 
-
-# def visualize_predictions(image, outputs, threshold):
-#     # keep only predictions with confidence >= threshold
-#     # convert predicted boxes from [0; 1] to image scales
-#     bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep].cpu(), image.size)
-
-#     # plot results
-#     plot_results(image, probas[keep], bboxes_scaled)
-
-#     return probas[keep]
-
-
 def visualize_probas(probas, threshold, colors):
     label_df = pd.DataFrame({"label":probas.max(-1).indices.detach().numpy(),
                              "proba":probas.max(-1).values.detach().numpy()})
@@ -136,9 +123,10 @@ def visualize_probas(probas, threshold, colors):
                  color="Item", title="Probability scores")
     st.plotly_chart(fig, use_container_width=True)
 
-    # chart = alt.Chart(top_label_df_agg).mark_bar().encode(x="proba", y="label",
-    #                   color=alt.Color('colors:N', scale=None)).interactive()
-    # st.altair_chart(chart)
+
+cats = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit',
+        'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar',
+        'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
 
 
 
@@ -182,49 +170,75 @@ st.info("""In this use case, we are going to identify and locate different artic
 
 st.markdown(" ")
 
-images_dior = [os.path.join("data/dior_show",url) for url in os.listdir("data/dior_show") if url != "results"]
+images_dior = [os.path.join("data/dior_show/images",url) for url in os.listdir("data/dior_show/images") if url != "results"]
 columns_img = st.columns(4)
 for img, col in zip(images_dior,columns_img):
     with col:
         st.image(img)
 
-
 st.markdown(" ")
 
 
+st.markdown("### About the model 📚")
+st.markdown("""The object detection model was trained specifically to **detect clothing items** on images. <br>
+It is able to detect <b>46</b> different types of clothing items.""", unsafe_allow_html=True)
+
+colors = ["#8ef", "#faa", "#afa", "#fea", "#8ef","#afa"]*7 + ["#8ef", "#faa", "#afa", "#fea"]
+
+cats_annotated = [(g,"","#afa") for g in cats]
+annotated_text([cats_annotated])
+
+# st.markdown("""**Here are the 'objects' the model is able to detect**: <br>
+# 'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt',
+# 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch',
+# 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel',
+# 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet',
+# 'ruffle', 'sequin', 'tassel'""", unsafe_allow_html=True)
+
+st.markdown("Credits: https://huggingface.co/valentinafeve/yolos-fashionpedia")
+st.markdown("")
+st.markdown("")
+
+
+
 ############## SELECT AN IMAGE ###############
 
-st.markdown("#### Select an image 🖼️")
+st.markdown("### Select an image 🖼️")
 #st.markdown("""**Select an image that you wish to run the Object Detection model on.**""")
 
-
 image_ = None
-select_image_box = st.radio(
-    "**Select the image you wish to run the model on**",
-    ["Choose an existing image", "Load your own image"],
-    index=None,)# #label_visibility="collapsed")
-
-if select_image_box == "Choose an existing image":
-    fashion_images_path = r"data/dior_show"
-    list_images = os.listdir(fashion_images_path)
-    image_ = st.selectbox("", list_images, label_visibility="collapsed")
+fashion_images_path = r"data/dior_show/images"
+list_images = os.listdir(fashion_images_path)
+image_ = st.selectbox("Select the image you wish to run the model on", list_images)
+image_ = os.path.join(fashion_images_path, image_)#label_visibility="collapsed")
+
+# image_ = None
+# select_image_box = st.radio(
+#     "**Select the image you wish to run the model on**",
+#     ["Choose an existing image", "Load your own image"],
+#     index=None,)# #label_visibility="collapsed")
+
+# if select_image_box == "Choose an existing image":
+#     fashion_images_path = r"data/dior_show/images"
+#     list_images = os.listdir(fashion_images_path)
+#     image_ = st.selectbox("", list_images, label_visibility="collapsed")
 
-    if image_ is not None:
-        image_ = os.path.join(fashion_images_path,image_)
-        st.markdown("You've selected the following image:")
-        st.image(image_, width=300)
-
-elif select_image_box == "Load your own image":
-    image_ = st.file_uploader("Load an image here",
-                              key="OD_dior", type=['jpg','jpeg','png'], label_visibility="collapsed")
+# if image_ is not None:
+#     image_ = os.path.join(fashion_images_path,image_)
+#     st.markdown("You've selected the following image:")
+#     st.image(image_, width=300)
+
+# elif select_image_box == "Load your own image":
+#     image_ = st.file_uploader("Load an image here",
+#                               key="OD_dior", type=['jpg','jpeg','png'], label_visibility="collapsed")
 
-st.warning("""**Note**: The model tends to perform better with images of people/clothing items facing forward.
-Choose this type of image if you want optimal results.""")
-st.warning("""**Note:** The model was trained to detect clothing items on a single person.
-If your image contains more than one person, the model won't detect the items of the other persons.""")
+# st.warning("""**Note**: The model tends to perform better with images of people/clothing items facing forward.
+# Choose this type of image if you want optimal results.""")
+# st.warning("""**Note:** The model was trained to detect clothing items on a single person.
+# If your image contains more than one person, the model won't detect the items of the other persons.""")
 
-if image_ is not None:
-    st.image(Image.open(image_), width=300)
+# if image_ is not None:
+#     st.image(Image.open(image_), width=300)
 
 
 st.markdown(" ")
@@ -234,46 +248,41 @@ st.markdown(" ")
 
 ########## SELECT AN ELEMENT TO DETECT ##################
 
-cats = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit',
-        'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar',
-        'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
 
 dict_cats = dict(zip(np.arange(len(cats)), cats))
 
-st.markdown("#### Choose the elements you want to detect 👉")
+# st.markdown("#### Choose the elements you want to detect 👉")
 
-# Select one or more elements to detect
-container = st.container()
-selected_options = None
-all = st.checkbox("Select all")
+# # Select one or more elements to detect
+# container = st.container()
+# selected_options = None
+# all = st.checkbox("Select all")
 
-if all:
-    selected_options = container.multiselect("**Select one or more items**", cats, cats)
-else:
-    selected_options = container.multiselect("**Select one or more items**", cats)
+# if all:
+#     selected_options = container.multiselect("**Select one or more items**", cats, cats)
+# else:
+#     selected_options = container.multiselect("**Select one or more items**", cats)
 
 #cats = selected_options
+selected_options = cats
 dict_cats_final = {key:value for (key,value) in dict_cats.items() if value in selected_options}
 
 
-st.markdown(" ")
-st.markdown(" ")
+# st.markdown(" ")
+# st.markdown(" ")
 
 
 
 ############## SELECT A THRESHOLD ###############
 
-st.markdown("#### Define a threshold for predictions 🔎")
-st.markdown("""Object detection models assign to each element detected a **probability score**. <br>
-This score represents the model's belief in the accuracy of its prediction for a specific object.
-""", unsafe_allow_html=True)
+st.markdown("### Define a threshold for predictions 🔎")
+st.markdown("""This section allows you to control how confident you want your model to be with its predictions. <br>
+Objects that are given a lower score than the chosen threshold will be ignored in the final results.""", unsafe_allow_html=True)
 
-st.warning("**Note:** Objects that are assigned a lower score than the chosen threshold will be ignored in the final results.")
+st.markdown(" Below is an example of probability scores given by object detection models for each element detected.")
 
 
-_, col, _ = st.columns([0.2,0.6,0.2])
-with col:
-    st.image("images/probability_od.png", caption="Example of object detection with probability scores")
+st.image("images/probability_od.png", caption="Example with bounding boxes and probability scores given by object detection models")
 
 st.markdown(" ")
@@ -284,13 +293,18 @@ st.markdown("**Select a threshold** ")
 
 threshold = st.slider('**Select a threshold**', min_value=0.5, step=0.05, max_value=1.0, value=0.75, label_visibility="collapsed")
 
-if threshold < 0.6:
-    st.error("""**Warning**: Selecting a low threshold (below 0.6) could lead the model to make errors and detect too many objects.""")
+
+# if threshold < 0.6:
+#     st.error("""**Warning**: Selecting a low threshold (below 0.6) could lead the model to make errors and detect too many objects.""")
 
 st.write("You've selected a threshold at", threshold)
 st.markdown(" ")
 
 
+
+pickle_file_path = r"data/dior_show/results"
+
+
 ############# RUN MODEL ################
 
 run_model = st.button("**Run the model**", type="primary")
@@ -299,19 +313,30 @@ if run_model:
     if image_ != None and selected_options != None and threshold!= None:
         with st.spinner('Wait for it...'):
             ## SELECT IMAGE
+            #st.write(image_)
            image = Image.open(image_)
            image = fix_channels(ToTensor()(image))
 
            ## LOAD OBJECT DETECTION MODEL
            FEATURE_EXTRACTOR_PATH = "hustvl/yolos-small"
            MODEL_PATH = "valentinafeve/yolos-fashionpedia"
-           feature_extractor, model = load_model(FEATURE_EXTRACTOR_PATH, MODEL_PATH)
-           # feature_extractor = YolosFeatureExtractor.from_pretrained('hustvl/yolos-small')
-           # model = YolosForObjectDetection.from_pretrained(MODEL)
+           # feature_extractor, model = load_model(FEATURE_EXTRACTOR_PATH, MODEL_PATH)
 
-           # RUN MODEL ON IMAGE
-           inputs = feature_extractor(images=image, return_tensors="pt")
-           outputs = model(**inputs)
+           # # RUN MODEL ON IMAGE
+           # inputs = feature_extractor(images=image, return_tensors="pt")
+           # outputs = model(**inputs)
+
+           # Save results
+           # pickle_file_path = r"data/dior_show/results"
+           # image_name = image_.split('\\')[1][:5]
+           # with open(os.path.join(pickle_file_path, f"{image_name}_results.pkl"), 'wb') as file:
+           #     pickle.dump(outputs, file)
+
+           image_name = image_.split('\\')[1][:5]
+           path_load_pickle = os.path.join(pickle_file_path, f"{image_name}_results.pkl")
+           with open(path_load_pickle, 'rb') as pickle_file:
+               outputs = pickle.load(pickle_file)
+
            probas, keep = return_probas(outputs, threshold)
 
            st.markdown("#### See the results ☑️")
 