ugaray96 committed
Commit 0f734ea · unverified · 1 Parent(s): b3450a7

Refactor and improve model, app, and training components


- Update dependencies in requirements.txt with pinned versions
- Enhance device handling in model and dataset classes
- Improve Streamlit app caching and error handling
- Optimize training and retraining procedures
- Add support for MPS device in model selection
- Update .gitignore to include output directory

Signed-off-by: Unai Garay <unaigaraymaestre@gmail.com>

Files changed (6)
  1. .gitignore +2 -1
  2. app.py +44 -26
  3. dataset.py +17 -10
  4. model.py +77 -39
  5. requirements.txt +14 -11
  6. train.py +84 -49
.gitignore CHANGED
@@ -2,4 +2,5 @@ feedback*
 new_model/
 __pycache__/
 data/
-events.out.*
+events.out.*
+output/
app.py CHANGED
@@ -1,27 +1,30 @@
 import os
+import time
+
 import streamlit as st
 from PIL import Image
-import requests
-import io
-import time
+
 from model import ViTForImageClassification
 
 st.set_page_config(
-    page_title="Grocery Classifier",
-    page_icon="interface/shopping-cart.png",
-    initial_sidebar_state="expanded"
+    page_title="Grocery Classifier",
+    page_icon="interface/shopping-cart.png",
+    initial_sidebar_state="expanded",
 )
 
-@st.cache()
+
+@st.cache_resource()
 def load_model():
     with st.spinner("Loading model"):
-        model = ViTForImageClassification('google/vit-base-patch16-224')
-        model.load('model/')
+        model = ViTForImageClassification("google/vit-base-patch16-224")
+        model.load("model/")
     return model
-
+
+
 model = load_model()
 feedback_path = "feedback"
 
+
 def predict(image):
     print("Predicting...")
     # Load using PIL
@@ -29,21 +32,24 @@ def predict(image):
 
     prediction, confidence = model.predict(image)
 
-    return {'prediction': prediction[0], 'confidence': round(confidence[0], 3)}, image
+    return {"prediction": prediction[0], "confidence": round(confidence[0], 3)}, image
+
 
 def submit_feedback(correct_label, image):
     folder_path = feedback_path + "/" + correct_label + "/"
     os.makedirs(folder_path, exist_ok=True)
     image.save(folder_path + correct_label + "_" + str(int(time.time())) + ".png")
-
+
+
 def retrain_from_feedback():
     model.retrain_from_path(feedback_path, remove_path=True)
 
+
 def main():
     labels = set(list(model.label_encoder.classes_))
 
     st.title("🍇 Grocery Classifier 🥑")
-
+
     if labels is None:
         st.warning("Received error from server, labels could not be retrieved")
     else:
@@ -54,37 +60,49 @@ def main():
         st.image(image_file)
 
         st.subheader("Classification")
-
+
         if st.button("Predict"):
-            st.session_state['response_json'], st.session_state['image'] = predict(image_file)
+            st.session_state["response_json"], st.session_state["image"] = predict(
+                image_file
+            )
 
-        if 'response_json' in st.session_state and st.session_state['response_json'] is not None:
+        if (
+            "response_json" in st.session_state
+            and st.session_state["response_json"] is not None
+        ):
             # Show the result
-            st.markdown(f"**Prediction:** {st.session_state['response_json']['prediction']}")
-            st.markdown(f"**Confidence:** {st.session_state['response_json']['confidence']}")
-
+            st.markdown(
+                f"**Prediction:** {st.session_state['response_json']['prediction']}"
+            )
+            st.markdown(
+                f"**Confidence:** {st.session_state['response_json']['confidence']}"
+            )
+
             # User feedback
             st.subheader("User Feedback")
-            st.markdown("If this prediction was incorrect, please select below the correct label")
+            st.markdown(
+                "If this prediction was incorrect, please select below the correct label"
+            )
             correct_labels = labels.copy()
-            correct_labels.remove(st.session_state['response_json']["prediction"])
+            correct_labels.remove(st.session_state["response_json"]["prediction"])
            correct_label = st.selectbox("Correct label", correct_labels)
            if st.button("Submit"):
                # Save feedback
                try:
-                    submit_feedback(correct_label, st.session_state['image'])
+                    submit_feedback(correct_label, st.session_state["image"])
                    st.success("Feedback submitted")
                except Exception as e:
                    st.error("Feedback could not be submitted. Error: {}".format(e))
-
+
            # Retrain from feedback
            if st.button("Retrain from feedback"):
                try:
-                    with st.spinner('Retraining...'):
+                    with st.spinner("Retraining..."):
                        retrain_from_feedback()
                    st.success("Model retrained")
                    st.balloons()
                except Exception as e:
                    st.warning("Model could not be retrained. Error: {}".format(e))
-
-main()
+
+
+main()
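
The caching change above replaces st.cache(), which is deprecated and removed in recent Streamlit releases, with st.cache_resource, the replacement for non-serializable singletons such as torch models. A minimal standalone sketch of the pattern, assuming a Streamlit version that ships st.cache_resource (1.18+); DummyModel is a hypothetical stand-in for ViTForImageClassification:

    import streamlit as st

    class DummyModel:  # hypothetical stand-in for the real model class
        def predict(self, x):
            return x

    @st.cache_resource()
    def load_model() -> DummyModel:
        # Body runs once per process; later reruns reuse the same object
        return DummyModel()

    model = load_model()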
dataset.py CHANGED
@@ -1,27 +1,34 @@
 import torch
 
+
 class RetailDataset(torch.utils.data.Dataset):
-    def __init__(self, data, labels=None, transform=None):
+    def __init__(self, data, labels=None, transform=None, device=None):
         self.data = data
         self.labels = labels
         self.num_classes = len(set(labels))
         self.transform = transform
+        self.device = device if device is not None else torch.device("cpu")
 
     def __getitem__(self, idx):
-        item = {key: val[idx].detach().clone() for key, val in self.data.items()}
-        item['labels'] = self.labels[idx]
+        item = {
+            key: torch.tensor(val[idx].detach().clone(), device=self.device)
+            for key, val in self.data.items()
+        }
+        item["labels"] = torch.tensor(self.labels[idx], device=self.device)
         return item
 
     def __len__(self):
         return len(self.labels)
 
     def __repr__(self):
-        return 'RetailDataset'
+        return "RetailDataset"
 
     def __str__(self):
-        return str({
-            'data': self.data['pixel_values'].shape,
-            'labels': self.labels.shape,
-            'num_classes': self.num_classes,
-            'num_samples': len(self.labels)
-        })
+        return str(
+            {
+                "data": self.data["pixel_values"].shape,
+                "labels": self.labels.shape,
+                "num_classes": self.num_classes,
+                "num_samples": len(self.labels),
+            }
+        )
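
A minimal usage sketch for the device-aware RetailDataset above, with made-up toy data (shapes and labels are illustrative only); it shows that __getitem__ now returns tensors already placed on the requested device:

    import numpy as np
    import torch

    from dataset import RetailDataset

    # Four fake ViT-sized images and binary labels, for illustration
    data = {"pixel_values": torch.randn(4, 3, 224, 224)}
    labels = np.array([0, 1, 0, 1])

    ds = RetailDataset(data, labels, device=torch.device("cpu"))
    item = ds[0]
    print(item["pixel_values"].shape, item["labels"].item())  # tensors on ds.device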
model.py CHANGED
@@ -1,26 +1,39 @@
+import logging
+import os
 import shutil
 import time
+
 import numpy as np
-from tqdm import tqdm
-from transformers import ViTModel, ViTFeatureExtractor
-from transformers.modeling_outputs import SequenceClassifierOutput
-import torch.nn as nn
 import torch
+import torch.nn as nn
 from PIL import Image
-import logging
-import os
+from sklearn.metrics import classification_report
 from sklearn.preprocessing import LabelEncoder
+from tqdm import tqdm
+from transformers import ViTFeatureExtractor, ViTModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+
 from train import (
-    re_training, metric, f1_score,
-    classification_report
+    f1_score,
+    metric,
+    re_training,
 )
 
-data_path = os.environ.get('DATA_PATH', "./data")
+data_path = os.environ.get("DATA_PATH", "./data")
 
 logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
 logger = logging.getLogger(__name__)
 
+
 class ViTForImageClassification(nn.Module):
+    @staticmethod
+    def get_device():
+        if torch.cuda.is_available():
+            return torch.device("cuda")
+        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            return torch.device("mps")
+        return torch.device("cpu")
+
     def __init__(self, model_name, num_labels=24, dropout=0.25, image_size=224):
         logger.info("Loading model")
         super(ViTForImageClassification, self).__init__()
@@ -32,7 +45,8 @@ class ViTForImageClassification(nn.Module):
         self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
         self.num_labels = num_labels
         self.label_encoder = LabelEncoder()
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = self.get_device()
+        logger.info(f"Using device: {self.device}")
         self.model_name = model_name
         # To device
         self.vit.to(self.device)
@@ -44,7 +58,7 @@ class ViTForImageClassification(nn.Module):
         logger.info("Forwarding")
         pixel_values = pixel_values.to(self.device)
         outputs = self.vit(pixel_values=pixel_values)
-        output = self.dropout(outputs.last_hidden_state[:,0])
+        output = self.dropout(outputs.last_hidden_state[:, 0])
         logits = self.classifier(output)
 
         loss = None
@@ -61,17 +75,21 @@ class ViTForImageClassification(nn.Module):
 
     def preprocess_image(self, images):
         logger.info("Preprocessing images")
-        return self.feature_extractor(images, return_tensors='pt')
+        return self.feature_extractor(images, return_tensors="pt")
 
-    def predict(self, images, batch_size=32, classes_names=True, return_probabilities=False):
+    def predict(
+        self, images, batch_size=32, classes_names=True, return_probabilities=False
+    ):
         logger.info("Predicting")
         if not isinstance(images, list):
             images = [images]
         classes_list = []
         confidence_list = []
-        for bs in tqdm(range(0, len(images), batch_size), desc="Preprocessing training images"):
-            images_batch = [image for image in images[bs:bs+batch_size]]
-            images_batch = self.preprocess_image(images_batch)['pixel_values']
+        for bs in tqdm(
+            range(0, len(images), batch_size), desc="Preprocessing training images"
+        ):
+            images_batch = [image for image in images[bs : bs + batch_size]]
+            images_batch = self.preprocess_image(images_batch)["pixel_values"]
             sequence_classifier_output = self.forward(images_batch, None)
             # Get max prob
             probs = sequence_classifier_output.logits.softmax(dim=-1).tolist()
@@ -96,19 +114,23 @@ class ViTForImageClassification(nn.Module):
         logger.info("Loading model")
         # Load label encoder
         # Check if label encoder and model exists
-        if not os.path.exists(path + "/label_encoder.npy") or not os.path.exists(path + "/model.pt"):
+        if not os.path.exists(path + "/label_encoder.npy") or not os.path.exists(
+            path + "/model.pt"
+        ):
             logger.warning("Label encoder or model not found")
             return
         self.label_encoder.classes_ = np.load(path + "/label_encoder.npy")
         # Reload classifier layer
-        self.classifier = nn.Linear(self.vit.config.hidden_size, len(self.label_encoder.classes_))
-
+        self.classifier = nn.Linear(
+            self.vit.config.hidden_size, len(self.label_encoder.classes_)
+        )
+
         self.load_state_dict(torch.load(path + "/model.pt", map_location=self.device))
         self.vit.to(self.device)
         self.vit.eval()
         self.to(self.device)
         self.eval()
-
+
     def evaluate(self, images, labels):
         logger.info("Evaluating")
         labels = self.label_encoder.transform(labels)
@@ -117,11 +139,18 @@ class ViTForImageClassification(nn.Module):
         # Evaluate
         metrics = metric.compute(predictions=y_pred, references=labels)
         f1 = f1_score.compute(predictions=y_pred, references=labels, average="macro")
-        print(classification_report(labels, y_pred, labels=[i for i in range(len(self.label_encoder.classes_))], target_names=self.label_encoder.classes_))
+        print(
+            classification_report(
+                labels,
+                y_pred,
+                labels=[i for i in range(len(self.label_encoder.classes_))],
+                target_names=self.label_encoder.classes_,
+            )
+        )
         print(f"Accuracy: {metrics['accuracy']}")
         print(f"F1: {f1}")
-
-    def partial_fit(self, images, labels, save_model_path='new_model', num_epochs=10):
+
+    def partial_fit(self, images, labels, save_model_path="new_model", num_epochs=10):
         logger.info("Partial fitting")
         # Freeze ViT model but last layer
         # params = [param for param in self.vit.parameters()]
@@ -135,21 +164,27 @@ class ViTForImageClassification(nn.Module):
         self.vit.eval()
         self.eval()
         self.evaluate(images, labels)
-
+
     def __load_from_path(self, path, num_per_label=None):
         images = []
         labels = []
         for label in os.listdir(path):
             count = 0
             label_folder_path = os.path.join(path, label)
-            for image_file in tqdm(os.listdir(label_folder_path), desc="Resizing images for label {}".format(label)):
+            for image_file in tqdm(
+                os.listdir(label_folder_path),
+                desc="Resizing images for label {}".format(label),
+            ):
                 file_path = os.path.join(label_folder_path, image_file)
                 try:
                     image = Image.open(file_path)
-                    image_shape = (self.feature_extractor.size, self.feature_extractor.size)
+                    image_shape = (
+                        self.feature_extractor.size,
+                        self.feature_extractor.size,
+                    )
                     if image.size != image_shape:
                         image = image.resize(image_shape)
-                    images.append(image.convert('RGB'))
+                    images.append(image.convert("RGB"))
                     labels.append(label)
                     count += 1
                 except Exception as e:
@@ -157,14 +192,16 @@ class ViTForImageClassification(nn.Module):
                 if num_per_label is not None and count >= num_per_label:
                     break
         return images, labels
-
-    def retrain_from_path(self,
-                          path='./data/feedback',
-                          num_per_label=None,
-                          save_model_path='new_model',
-                          remove_path=False,
-                          num_epochs=10,
-                          save_new_data=data_path + '/new_data'):
+
+    def retrain_from_path(
+        self,
+        path="./data/feedback",
+        num_per_label=None,
+        save_model_path="new_model",
+        remove_path=False,
+        num_epochs=10,
+        save_new_data=data_path + "/new_data",
+    ):
         logger.info("Retraining from path")
         # Load path
         images, labels = self.__load_from_path(path, num_per_label)
@@ -173,19 +210,20 @@ class ViTForImageClassification(nn.Module):
         # Save new data
         if save_new_data is not None:
             logger.info("Saving new data")
-            for i ,(image, label) in enumerate(zip(images, labels)):
+            for i, (image, label) in enumerate(zip(images, labels)):
                 label_path = os.path.join(save_new_data, label)
                 os.makedirs(label_path, exist_ok=True)
-                image.save(os.path.join(label_path, str(int(time.time())) + f"_{i}.jpg"))
+                image.save(
+                    os.path.join(label_path, str(int(time.time())) + f"_{i}.jpg")
+                )
         # Remove path folder
         if remove_path:
             logger.info("Removing feedback path")
             shutil.rmtree(path)
-
+
     def evaluate_from_path(self, path, num_per_label=None):
         logger.info("Evaluating from path")
         # Load images
         images, labels = self.__load_from_path(path, num_per_label)
         # Evaluate
         self.evaluate(images, labels)
-
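
The CUDA → MPS → CPU fallback introduced in get_device above, extracted as a standalone sketch (pick_device is a hypothetical local name); the hasattr guard keeps the check safe on torch builds older than 1.12, which lack torch.backends.mps:

    import torch

    def pick_device() -> torch.device:
        # Prefer CUDA, then Apple Silicon MPS, then fall back to CPU
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    print(pick_device())  # e.g. device(type='mps') on an M-series Mac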
 
requirements.txt CHANGED
@@ -1,11 +1,14 @@
-Pillow
-requests
-numpy
-transformers
-scikit-learn
-datasets
-streamlit
-matplotlib
-scikit-image
-torch
-torchvision
+Pillow==10.4.0
+requests==2.32.3
+numpy==1.24.4
+transformers==4.46.3
+scikit-learn==1.3.2
+datasets==3.1.0
+streamlit==1.40.1
+matplotlib==3.7.5
+scikit-image==0.21.0
+torch==2.4.1
+torchvision==0.19.1
+altair==5.4.1
+evaluate==0.4.3
+accelerate==1.0.1
train.py CHANGED
@@ -1,12 +1,10 @@
+import logging
 import os
+
 import numpy as np
-from sklearn.metrics import classification_report
-from tqdm import tqdm
-import logging
-from sklearn.model_selection import train_test_split
-from dataset import RetailDataset
+from evaluate import load
 from PIL import Image
-from datasets import load_metric
+from sklearn.model_selection import train_test_split
 from torchvision.transforms import (
     CenterCrop,
     Compose,
@@ -16,28 +14,38 @@ from torchvision.transforms import (
     Resize,
     ToTensor,
 )
-from transformers import Trainer, TrainingArguments, BatchFeature
-metric = load_metric("accuracy")
-f1_score = load_metric("f1")
+from tqdm import tqdm
+from transformers import BatchFeature, Trainer, TrainingArguments
+
+from dataset import RetailDataset
+
+metric = load("accuracy")
+f1_score = load("f1")
 np.random.seed(42)
 
 logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
 logger = logging.getLogger(__name__)
-
-def prepare_dataset(images,
-                    labels,
-                    model,
-                    test_size=.2,
-                    train_transform=None,
-                    val_transform=None,
-                    batch_size=512):
+
+
+def prepare_dataset(
+    images,
+    labels,
+    model,
+    test_size=0.2,
+    train_transform=None,
+    val_transform=None,
+    batch_size=512,
+):
     logger.info("Preparing dataset")
     # Split the dataset in train and test
     try:
-        images_train, images_test, labels_train, labels_test = \
-            train_test_split(images, labels, test_size=test_size)
+        images_train, images_test, labels_train, labels_test = train_test_split(
+            images, labels, test_size=test_size
+        )
     except ValueError:
-        logger.warning("Could not split dataset. Using all data for training and testing")
+        logger.warning(
+            "Could not split dataset. Using all data for training and testing"
+        )
         images_train = images
         labels_train = labels
         images_test = images
@@ -46,14 +54,24 @@ def prepare_dataset(images,
     # Preprocess images using model feature extractor
     images_train_prep = []
     images_test_prep = []
-    for bs in tqdm(range(0, len(images_train), batch_size), desc="Preprocessing training images"):
-        images_train_batch = [Image.fromarray(np.array(image)) for image in images_train[bs:bs+batch_size]]
+    for bs in tqdm(
+        range(0, len(images_train), batch_size), desc="Preprocessing training images"
+    ):
+        images_train_batch = [
+            Image.fromarray(np.array(image))
+            for image in images_train[bs : bs + batch_size]
+        ]
         images_train_batch = model.preprocess_image(images_train_batch)
-        images_train_prep.extend(images_train_batch['pixel_values'])
-    for bs in tqdm(range(0, len(images_test), batch_size), desc="Preprocessing test images"):
-        images_test_batch = [Image.fromarray(np.array(image)) for image in images_test[bs:bs+batch_size]]
+        images_train_prep.extend(images_train_batch["pixel_values"])
+    for bs in tqdm(
+        range(0, len(images_test), batch_size), desc="Preprocessing test images"
+    ):
+        images_test_batch = [
+            Image.fromarray(np.array(image))
+            for image in images_test[bs : bs + batch_size]
+        ]
         images_test_batch = model.preprocess_image(images_test_batch)
-        images_test_prep.extend(images_test_batch['pixel_values'])
+        images_test_prep.extend(images_test_batch["pixel_values"])
 
     # Create BatchFeatures
     images_train_prep = {"pixel_values": images_train_prep}
@@ -61,50 +79,67 @@ def prepare_dataset(
     images_test_prep = {"pixel_values": images_test_prep}
     test_batch_features = BatchFeature(data=images_test_prep)
 
-    # Create the datasets
-    train_dataset = RetailDataset(train_batch_features, labels_train, train_transform)
-    test_dataset = RetailDataset(test_batch_features, labels_test, val_transform)
+    # Create the datasets with proper device
+    train_dataset = RetailDataset(
+        train_batch_features, labels_train, train_transform, device=model.device
+    )
+    test_dataset = RetailDataset(
+        test_batch_features, labels_test, val_transform, device=model.device
+    )
     logger.info("Train dataset: %d images", len(labels_train))
     logger.info("Test dataset: %d images", len(labels_test))
     return train_dataset, test_dataset
 
-def re_training(images, labels, _model, save_model_path='new_model', num_epochs=10):
+
+def re_training(images, labels, _model, save_model_path="new_model", num_epochs=10):
     global model
     model = _model
     labels = model.label_encoder.transform(labels)
-    normalize = Normalize(mean=model.feature_extractor.image_mean, std=model.feature_extractor.image_std)
+    normalize = Normalize(
+        mean=model.feature_extractor.image_mean, std=model.feature_extractor.image_std
+    )
+
     def train_transforms(batch):
-        return Compose([
-            RandomResizedCrop(model.feature_extractor.size),
-            RandomHorizontalFlip(),
-            ToTensor(),
-            normalize,
-        ])(batch)
+        return Compose(
+            [
+                RandomResizedCrop(model.feature_extractor.size),
+                RandomHorizontalFlip(),
+                ToTensor(),
+                normalize,
+            ]
+        )(batch)
 
     def val_transforms(batch):
-        return Compose([
-            Resize(model.feature_extractor.size),
-            CenterCrop(model.feature_extractor.size),
-            ToTensor(),
-            normalize,
-        ])(batch)
+        return Compose(
+            [
+                Resize(model.feature_extractor.size),
+                CenterCrop(model.feature_extractor.size),
+                ToTensor(),
+                normalize,
+            ]
+        )(batch)
+
     train_dataset, test_dataset = prepare_dataset(
-        images, labels, model, .2, train_transforms, val_transforms)
+        images, labels, model, 0.2, train_transforms, val_transforms
+    )
+
     trainer = Trainer(
         model=model,
         args=TrainingArguments(
-            output_dir='output',
+            output_dir="output",
            overwrite_output_dir=True,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=32,
            gradient_accumulation_steps=1,
            learning_rate=0.000001,
            weight_decay=0.01,
-            evaluation_strategy='steps',
+            eval_strategy="steps",
            eval_steps=1000,
-            save_steps=3000),
+            save_steps=3000,
+            use_cpu=model.device.type == "cpu",  # Only force CPU if that's our device
+        ),
         train_dataset=train_dataset,
-        eval_dataset=test_dataset
+        eval_dataset=test_dataset,
     )
     trainer.train()
-    model.save(save_model_path)
+    model.save(save_model_path)
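
On the TrainingArguments changes above: with the transformers version pinned in requirements.txt (4.46.3), evaluation_strategy is deprecated in favour of eval_strategy, and use_cpu is the supported flag for forcing CPU-only training. A minimal sketch of just the arguments object, mirroring the values in the diff (assuming accelerate is installed, as pinned above):

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="output",
        overwrite_output_dir=True,
        num_train_epochs=10,
        per_device_train_batch_size=32,
        learning_rate=1e-6,
        weight_decay=0.01,
        eval_strategy="steps",  # formerly evaluation_strategy="steps"
        eval_steps=1000,
        save_steps=3000,
        use_cpu=False,  # True forces CPU even when CUDA/MPS is available
    )
    print(args.eval_strategy)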