Refactor and improve model, app, and training components
- Update dependencies in requirements.txt with pinned versions
- Enhance device handling in model and dataset classes
- Improve Streamlit app caching and error handling
- Optimize training and retraining procedures
- Add support for MPS device in model selection (see the device-selection sketch after the file list)
- Update .gitignore to include output directory
Signed-off-by: Unai Garay <unaigaraymaestre@gmail.com>
- .gitignore +2 -1
- app.py +44 -26
- dataset.py +17 -10
- model.py +77 -39
- requirements.txt +14 -11
- train.py +84 -49
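As a quick illustration of the MPS bullet above: the commit resolves the compute device in the order CUDA, then Apple-silicon MPS, then CPU. Below is a standalone sketch of that order; pick_device is a hypothetical name, the commit's actual helper is ViTForImageClassification.get_device() in model.py.

import torch

def pick_device() -> torch.device:
    # Prefer CUDA, fall back to Apple-silicon MPS, else CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

print(pick_device())  # e.g. "mps" on Apple silicon, "cuda" on an NVIDIA box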
.gitignore
CHANGED
@@ -2,4 +2,5 @@ feedback*
 new_model/
 __pycache__/
 data/
-events.out.*
+events.out.*
+output/
app.py
CHANGED
(updated code shown per hunk)
@@ -1,27 +1,30 @@
import os
import time

import streamlit as st
from PIL import Image

from model import ViTForImageClassification

st.set_page_config(
    page_title="Grocery Classifier",
    page_icon="interface/shopping-cart.png",
    initial_sidebar_state="expanded",
)


@st.cache_resource()
def load_model():
    with st.spinner("Loading model"):
        model = ViTForImageClassification("google/vit-base-patch16-224")
        model.load("model/")
    return model


model = load_model()
feedback_path = "feedback"


def predict(image):
    print("Predicting...")
    # Load using PIL
@@ -29,21 +32,24 @@ def predict(image):

    prediction, confidence = model.predict(image)

    return {"prediction": prediction[0], "confidence": round(confidence[0], 3)}, image


def submit_feedback(correct_label, image):
    folder_path = feedback_path + "/" + correct_label + "/"
    os.makedirs(folder_path, exist_ok=True)
    image.save(folder_path + correct_label + "_" + str(int(time.time())) + ".png")


def retrain_from_feedback():
    model.retrain_from_path(feedback_path, remove_path=True)


def main():
    labels = set(list(model.label_encoder.classes_))

    st.title("🍇 Grocery Classifier 🥑")

    if labels is None:
        st.warning("Received error from server, labels could not be retrieved")
    else:
@@ -54,37 +60,49 @@ def main():
        st.image(image_file)

        st.subheader("Classification")

        if st.button("Predict"):
            st.session_state["response_json"], st.session_state["image"] = predict(
                image_file
            )

        if (
            "response_json" in st.session_state
            and st.session_state["response_json"] is not None
        ):
            # Show the result
            st.markdown(
                f"**Prediction:** {st.session_state['response_json']['prediction']}"
            )
            st.markdown(
                f"**Confidence:** {st.session_state['response_json']['confidence']}"
            )

            # User feedback
            st.subheader("User Feedback")
            st.markdown(
                "If this prediction was incorrect, please select below the correct label"
            )
            correct_labels = labels.copy()
            correct_labels.remove(st.session_state["response_json"]["prediction"])
            correct_label = st.selectbox("Correct label", correct_labels)
            if st.button("Submit"):
                # Save feedback
                try:
                    submit_feedback(correct_label, st.session_state["image"])
                    st.success("Feedback submitted")
                except Exception as e:
                    st.error("Feedback could not be submitted. Error: {}".format(e))

            # Retrain from feedback
            if st.button("Retrain from feedback"):
                try:
                    with st.spinner("Retraining..."):
                        retrain_from_feedback()
                    st.success("Model retrained")
                    st.balloons()
                except Exception as e:
                    st.warning("Model could not be retrained. Error: {}".format(e))


main()
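The @st.cache_resource() decorator in the updated app.py is what keeps the ViT weights from reloading on every Streamlit rerun. A minimal sketch of the pattern, using a toy function rather than anything from the commit:

import streamlit as st

@st.cache_resource()
def expensive_resource():
    # The body runs once per server process; later reruns of the script
    # receive the same cached object instead of rebuilding it.
    return object()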
dataset.py
CHANGED
@@ -1,27 +1,34 @@
import torch


class RetailDataset(torch.utils.data.Dataset):
    def __init__(self, data, labels=None, transform=None, device=None):
        self.data = data
        self.labels = labels
        self.num_classes = len(set(labels))
        self.transform = transform
        self.device = device if device is not None else torch.device("cpu")

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(val[idx].detach().clone(), device=self.device)
            for key, val in self.data.items()
        }
        item["labels"] = torch.tensor(self.labels[idx], device=self.device)
        return item

    def __len__(self):
        return len(self.labels)

    def __repr__(self):
        return "RetailDataset"

    def __str__(self):
        return str(
            {
                "data": self.data["pixel_values"].shape,
                "labels": self.labels.shape,
                "num_classes": self.num_classes,
                "num_samples": len(self.labels),
            }
        )
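A minimal usage sketch for the updated class, with dummy tensors standing in for real preprocessed features (shapes and labels here are illustrative only):

import torch
from dataset import RetailDataset

features = {"pixel_values": torch.randn(4, 3, 224, 224)}  # fake ViT inputs
labels = [0, 1, 2, 1]
ds = RetailDataset(features, labels, device=torch.device("cpu"))
print(len(ds), ds[0]["pixel_values"].shape, ds[0]["labels"])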
model.py
CHANGED
(updated code shown per hunk)
@@ -1,26 +1,39 @@
import logging
import os
import shutil
import time

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from transformers import ViTFeatureExtractor, ViTModel
from transformers.modeling_outputs import SequenceClassifierOutput

from train import (
    f1_score,
    metric,
    re_training,
)

data_path = os.environ.get("DATA_PATH", "./data")

logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
logger = logging.getLogger(__name__)


class ViTForImageClassification(nn.Module):
    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def __init__(self, model_name, num_labels=24, dropout=0.25, image_size=224):
        logger.info("Loading model")
        super(ViTForImageClassification, self).__init__()
@@ -32,7 +45,8 @@ class ViTForImageClassification(nn.Module):
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels
        self.label_encoder = LabelEncoder()
        self.device = self.get_device()
        logger.info(f"Using device: {self.device}")
        self.model_name = model_name
        # To device
        self.vit.to(self.device)
@@ -44,7 +58,7 @@ class ViTForImageClassification(nn.Module):
        logger.info("Forwarding")
        pixel_values = pixel_values.to(self.device)
        outputs = self.vit(pixel_values=pixel_values)
        output = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(output)

        loss = None
@@ -61,17 +75,21 @@ class ViTForImageClassification(nn.Module):

    def preprocess_image(self, images):
        logger.info("Preprocessing images")
        return self.feature_extractor(images, return_tensors="pt")

    def predict(
        self, images, batch_size=32, classes_names=True, return_probabilities=False
    ):
        logger.info("Predicting")
        if not isinstance(images, list):
            images = [images]
        classes_list = []
        confidence_list = []
        for bs in tqdm(
            range(0, len(images), batch_size), desc="Preprocessing training images"
        ):
            images_batch = [image for image in images[bs : bs + batch_size]]
            images_batch = self.preprocess_image(images_batch)["pixel_values"]
            sequence_classifier_output = self.forward(images_batch, None)
            # Get max prob
            probs = sequence_classifier_output.logits.softmax(dim=-1).tolist()
@@ -96,19 +114,23 @@ class ViTForImageClassification(nn.Module):
        logger.info("Loading model")
        # Load label encoder
        # Check if label encoder and model exists
        if not os.path.exists(path + "/label_encoder.npy") or not os.path.exists(
            path + "/model.pt"
        ):
            logger.warning("Label encoder or model not found")
            return
        self.label_encoder.classes_ = np.load(path + "/label_encoder.npy")
        # Reload classifier layer
        self.classifier = nn.Linear(
            self.vit.config.hidden_size, len(self.label_encoder.classes_)
        )

        self.load_state_dict(torch.load(path + "/model.pt", map_location=self.device))
        self.vit.to(self.device)
        self.vit.eval()
        self.to(self.device)
        self.eval()

    def evaluate(self, images, labels):
        logger.info("Evaluating")
        labels = self.label_encoder.transform(labels)
@@ -117,11 +139,18 @@ class ViTForImageClassification(nn.Module):
        # Evaluate
        metrics = metric.compute(predictions=y_pred, references=labels)
        f1 = f1_score.compute(predictions=y_pred, references=labels, average="macro")
        print(
            classification_report(
                labels,
                y_pred,
                labels=[i for i in range(len(self.label_encoder.classes_))],
                target_names=self.label_encoder.classes_,
            )
        )
        print(f"Accuracy: {metrics['accuracy']}")
        print(f"F1: {f1}")

    def partial_fit(self, images, labels, save_model_path="new_model", num_epochs=10):
        logger.info("Partial fitting")
        # Freeze ViT model but last layer
        # params = [param for param in self.vit.parameters()]
@@ -135,21 +164,27 @@ class ViTForImageClassification(nn.Module):
        self.vit.eval()
        self.eval()
        self.evaluate(images, labels)

    def __load_from_path(self, path, num_per_label=None):
        images = []
        labels = []
        for label in os.listdir(path):
            count = 0
            label_folder_path = os.path.join(path, label)
            for image_file in tqdm(
                os.listdir(label_folder_path),
                desc="Resizing images for label {}".format(label),
            ):
                file_path = os.path.join(label_folder_path, image_file)
                try:
                    image = Image.open(file_path)
                    image_shape = (
                        self.feature_extractor.size,
                        self.feature_extractor.size,
                    )
                    if image.size != image_shape:
                        image = image.resize(image_shape)
                    images.append(image.convert("RGB"))
                    labels.append(label)
                    count += 1
                except Exception as e:
@@ -157,14 +192,16 @@ class ViTForImageClassification(nn.Module):
                if num_per_label is not None and count >= num_per_label:
                    break
        return images, labels

    def retrain_from_path(
        self,
        path="./data/feedback",
        num_per_label=None,
        save_model_path="new_model",
        remove_path=False,
        num_epochs=10,
        save_new_data=data_path + "/new_data",
    ):
        logger.info("Retraining from path")
        # Load path
        images, labels = self.__load_from_path(path, num_per_label)
@@ -173,19 +210,20 @@ class ViTForImageClassification(nn.Module):
        # Save new data
        if save_new_data is not None:
            logger.info("Saving new data")
            for i, (image, label) in enumerate(zip(images, labels)):
                label_path = os.path.join(save_new_data, label)
                os.makedirs(label_path, exist_ok=True)
                image.save(
                    os.path.join(label_path, str(int(time.time())) + f"_{i}.jpg")
                )
        # Remove path folder
        if remove_path:
            logger.info("Removing feedback path")
            shutil.rmtree(path)

    def evaluate_from_path(self, path, num_per_label=None):
        logger.info("Evaluating from path")
        # Load images
        images, labels = self.__load_from_path(path, num_per_label)
        # Evaluate
        self.evaluate(images, labels)
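Putting the pieces together, app.py drives this class roughly as follows. A sketch assuming the checkpoint layout that load() expects (model/label_encoder.npy plus model/model.pt); the input file name is hypothetical and the constructor downloads the ViT backbone on first use:

from PIL import Image
from model import ViTForImageClassification

model = ViTForImageClassification("google/vit-base-patch16-224")
model.load("model/")  # needs model/label_encoder.npy and model/model.pt
prediction, confidence = model.predict(Image.open("groceries.png"))
print(prediction[0], confidence[0])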
requirements.txt
CHANGED
@@ -1,11 +1,14 @@
-Pillow
-requests
-numpy
-transformers
-scikit-learn
-datasets
-streamlit
-matplotlib
-scikit-image
-torch
-torchvision
+Pillow==10.4.0
+requests==2.32.3
+numpy==1.24.4
+transformers==4.46.3
+scikit-learn==1.3.2
+datasets==3.1.0
+streamlit==1.40.1
+matplotlib==3.7.5
+scikit-image==0.21.0
+torch==2.4.1
+torchvision==0.19.1
+altair==5.4.1
+evaluate==0.4.3
+accelerate==1.0.1
train.py
CHANGED
(updated code shown per hunk)
@@ -1,12 +1,10 @@
import logging
import os

import numpy as np
from evaluate import load
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision.transforms import (
    CenterCrop,
    Compose,
@@ -16,28 +14,38 @@ from torchvision.transforms import (
    Resize,
    ToTensor,
)
from tqdm import tqdm
from transformers import BatchFeature, Trainer, TrainingArguments

from dataset import RetailDataset

metric = load("accuracy")
f1_score = load("f1")
np.random.seed(42)

logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
logger = logging.getLogger(__name__)


def prepare_dataset(
    images,
    labels,
    model,
    test_size=0.2,
    train_transform=None,
    val_transform=None,
    batch_size=512,
):
    logger.info("Preparing dataset")
    # Split the dataset in train and test
    try:
        images_train, images_test, labels_train, labels_test = train_test_split(
            images, labels, test_size=test_size
        )
    except ValueError:
        logger.warning(
            "Could not split dataset. Using all data for training and testing"
        )
        images_train = images
        labels_train = labels
        images_test = images
@@ -46,14 +54,24 @@ def prepare_dataset(images,
    # Preprocess images using model feature extractor
    images_train_prep = []
    images_test_prep = []
    for bs in tqdm(
        range(0, len(images_train), batch_size), desc="Preprocessing training images"
    ):
        images_train_batch = [
            Image.fromarray(np.array(image))
            for image in images_train[bs : bs + batch_size]
        ]
        images_train_batch = model.preprocess_image(images_train_batch)
        images_train_prep.extend(images_train_batch["pixel_values"])
    for bs in tqdm(
        range(0, len(images_test), batch_size), desc="Preprocessing test images"
    ):
        images_test_batch = [
            Image.fromarray(np.array(image))
            for image in images_test[bs : bs + batch_size]
        ]
        images_test_batch = model.preprocess_image(images_test_batch)
        images_test_prep.extend(images_test_batch["pixel_values"])

    # Create BatchFeatures
    images_train_prep = {"pixel_values": images_train_prep}
@@ -61,50 +79,67 @@ def prepare_dataset(images,
    images_test_prep = {"pixel_values": images_test_prep}
    test_batch_features = BatchFeature(data=images_test_prep)

    # Create the datasets with proper device
    train_dataset = RetailDataset(
        train_batch_features, labels_train, train_transform, device=model.device
    )
    test_dataset = RetailDataset(
        test_batch_features, labels_test, val_transform, device=model.device
    )
    logger.info("Train dataset: %d images", len(labels_train))
    logger.info("Test dataset: %d images", len(labels_test))
    return train_dataset, test_dataset


def re_training(images, labels, _model, save_model_path="new_model", num_epochs=10):
    global model
    model = _model
    labels = model.label_encoder.transform(labels)
    normalize = Normalize(
        mean=model.feature_extractor.image_mean, std=model.feature_extractor.image_std
    )

    def train_transforms(batch):
        return Compose(
            [
                RandomResizedCrop(model.feature_extractor.size),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
            ]
        )(batch)

    def val_transforms(batch):
        return Compose(
            [
                Resize(model.feature_extractor.size),
                CenterCrop(model.feature_extractor.size),
                ToTensor(),
                normalize,
            ]
        )(batch)

    train_dataset, test_dataset = prepare_dataset(
        images, labels, model, 0.2, train_transforms, val_transforms
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="output",
            overwrite_output_dir=True,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=32,
            gradient_accumulation_steps=1,
            learning_rate=0.000001,
            weight_decay=0.01,
            eval_strategy="steps",
            eval_steps=1000,
            save_steps=3000,
            use_cpu=model.device.type == "cpu",  # Only force CPU if that's our device
        ),
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
    trainer.train()
    model.save(save_model_path)
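Finally, the feedback loop end to end. A sketch assuming a feedback/ folder laid out as feedback/<label>/<image>.png, which is the layout submit_feedback() in app.py writes; behavior past the shown hunks (exactly what retraining saves and where) follows the defaults visible above:

from model import ViTForImageClassification

model = ViTForImageClassification("google/vit-base-patch16-224")
model.load("model/")
# Should fine-tune on the feedback images and, per the defaults above,
# save the result under new_model/; keep remove_path=False to preserve
# the feedback folder while experimenting.
model.retrain_from_path("feedback", num_epochs=1, remove_path=False)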