import logging
import os
import shutil
import time

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from transformers import ViTModel, ViTFeatureExtractor
from transformers.modeling_outputs import SequenceClassifierOutput

from train import (
    re_training, metric, f1_score,
    classification_report
)

data_path = os.environ.get('DATA_PATH', "./data")

logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
logger = logging.getLogger(__name__)

class ViTForImageClassification(nn.Module):
    def __init__(self, model_name, num_labels=24, dropout=0.25, image_size=224):
        logger.info("Loading model")
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        self.feature_extractor.do_resize = True
        self.feature_extractor.size = image_size
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels
        self.label_encoder = LabelEncoder()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name

        # Moving the whole module also moves self.vit and self.classifier
        self.to(self.device)
        logger.info("Model loaded")

    def forward(self, pixel_values, labels=None):
        logger.info("Forwarding")
        pixel_values = pixel_values.to(self.device)
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the [CLS] token representation
        output = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def preprocess_image(self, images):
        logger.info("Preprocessing images")
        return self.feature_extractor(images, return_tensors='pt')

    def predict(self, images, batch_size=32, classes_names=True, return_probabilities=False):
        logger.info("Predicting")
        if not isinstance(images, list):
            images = [images]
        classes_list = []
        confidence_list = []
        probs_list = []
        with torch.no_grad():
            for bs in tqdm(range(0, len(images), batch_size), desc="Predicting"):
                images_batch = images[bs:bs + batch_size]
                pixel_values = self.preprocess_image(images_batch)['pixel_values']
                sequence_classifier_output = self.forward(pixel_values, None)

                # Accumulate per-batch results so the returned probabilities
                # cover every image, not just the last batch
                probs = sequence_classifier_output.logits.softmax(dim=-1).tolist()
                classes_list.extend(np.argmax(probs, axis=1))
                confidence_list.extend(np.max(probs, axis=1))
                probs_list.extend(probs)
        if classes_names:
            classes_list = self.label_encoder.inverse_transform(classes_list)
        if return_probabilities:
            return classes_list, confidence_list, probs_list
        return classes_list, confidence_list

    def save(self, path):
        logger.info("Saving model")
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, "model.pt"))
        np.save(os.path.join(path, "label_encoder.npy"), self.label_encoder.classes_)

    def load(self, path):
        logger.info("Loading model")
        model_file = os.path.join(path, "model.pt")
        encoder_file = os.path.join(path, "label_encoder.npy")
        if not os.path.exists(encoder_file) or not os.path.exists(model_file):
            logger.warning("Label encoder or model not found")
            return
        self.label_encoder.classes_ = np.load(encoder_file)

        # Rebuild the classification head so it matches the saved label set
        self.num_labels = len(self.label_encoder.classes_)
        self.classifier = nn.Linear(self.vit.config.hidden_size, self.num_labels)

        self.load_state_dict(torch.load(model_file, map_location=self.device))
        self.to(self.device)
        self.eval()

    def evaluate(self, images, labels):
        logger.info("Evaluating")
        labels = self.label_encoder.transform(labels)

        y_pred, _ = self.predict(images, classes_names=False)

        metrics = metric.compute(predictions=y_pred, references=labels)
        f1 = f1_score.compute(predictions=y_pred, references=labels, average="macro")
        print(classification_report(
            labels, y_pred,
            labels=list(range(len(self.label_encoder.classes_))),
            target_names=self.label_encoder.classes_,
        ))
        print(f"Accuracy: {metrics['accuracy']}")
        print(f"F1: {f1}")

    def partial_fit(self, images, labels, save_model_path='new_model', num_epochs=10):
        logger.info("Partial fitting")

        # self.train()/self.eval() also switch self.vit and self.classifier
        self.train()
        re_training(images, labels, self, save_model_path, num_epochs)
        self.load(save_model_path)
        self.eval()
        self.evaluate(images, labels)

    def __load_from_path(self, path, num_per_label=None):
        images = []
        labels = []
        image_shape = (self.feature_extractor.size, self.feature_extractor.size)
        for label in os.listdir(path):
            count = 0
            label_folder_path = os.path.join(path, label)
            for image_file in tqdm(os.listdir(label_folder_path), desc="Resizing images for label {}".format(label)):
                file_path = os.path.join(label_folder_path, image_file)
                try:
                    image = Image.open(file_path)
                    if image.size != image_shape:
                        image = image.resize(image_shape)
                    images.append(image.convert('RGB'))
                    labels.append(label)
                    count += 1
                except Exception as e:
                    print(f"ERROR - Could not resize image {file_path} - {e}")
                if num_per_label is not None and count >= num_per_label:
                    break
        return images, labels

    def retrain_from_path(self,
                          path='./data/feedback',
                          num_per_label=None,
                          save_model_path='new_model',
                          remove_path=False,
                          num_epochs=10,
                          save_new_data=data_path + '/new_data'):
        logger.info("Retraining from path")

        images, labels = self.__load_from_path(path, num_per_label)

        self.partial_fit(images, labels, save_model_path, num_epochs)

        if save_new_data is not None:
            logger.info("Saving new data")
            for i, (image, label) in enumerate(zip(images, labels)):
                label_path = os.path.join(save_new_data, label)
                os.makedirs(label_path, exist_ok=True)
                image.save(os.path.join(label_path, str(int(time.time())) + f"_{i}.jpg"))

        if remove_path:
            logger.info("Removing feedback path")
            shutil.rmtree(path)

    def evaluate_from_path(self, path, num_per_label=None):
        logger.info("Evaluating from path")

        images, labels = self.__load_from_path(path, num_per_label)

        self.evaluate(images, labels)

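
# Example usage: a minimal sketch assuming a checkpoint previously written by
# save()/re_training() exists under ./new_model; the pretrained checkpoint name
# and image path below are illustrative assumptions, not values shipped with
# this module.
if __name__ == "__main__":
    model = ViTForImageClassification("google/vit-base-patch16-224-in21k")
    model.load("new_model")  # restores the weights and the fitted label encoder

    image = Image.open("example.jpg").convert("RGB")
    classes, confidences = model.predict(image)
    print(classes, confidences)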