ugaray96's picture
Refactor and improve model, app, and training components
0f734ea unverified
raw
history blame
4.47 kB
import logging
import os
import numpy as np
from evaluate import load
from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
)
from tqdm import tqdm
from transformers import BatchFeature, Trainer, TrainingArguments
from dataset import RetailDataset
# Evaluation metrics. NOTE(review): neither `metric` nor `f1_score` is
# referenced in this chunk — presumably consumed by a compute_metrics
# callback defined elsewhere; confirm before removing.
metric = load("accuracy")
f1_score = load("f1")
# Seed NumPy's global RNG so downstream random operations (e.g. the
# train/test split in prepare_dataset) are reproducible across runs.
np.random.seed(42)
# Log level is configurable via the LOGGER_LEVEL environment variable;
# defaults to WARNING when unset.
logging.basicConfig(level=os.getenv("LOGGER_LEVEL", logging.WARNING))
logger = logging.getLogger(__name__)
def _preprocess_in_batches(images, model, batch_size, desc):
    """Run ``model.preprocess_image`` over *images* in fixed-size batches.

    Converts each raw image to a PIL Image first, then feeds batches to the
    model's preprocessor. Returns a flat list of per-image pixel-value arrays.
    """
    prepped = []
    for start in tqdm(range(0, len(images), batch_size), desc=desc):
        batch = [
            Image.fromarray(np.array(image))
            for image in images[start : start + batch_size]
        ]
        batch = model.preprocess_image(batch)
        prepped.extend(batch["pixel_values"])
    return prepped


def prepare_dataset(
    images,
    labels,
    model,
    test_size=0.2,
    train_transform=None,
    val_transform=None,
    batch_size=512,
):
    """Split, preprocess and wrap *images*/*labels* into train/test datasets.

    Parameters
    ----------
    images, labels : equal-length sequences of raw images and their labels.
    model : object exposing ``preprocess_image`` and ``device``
        (presumably the project's model wrapper — confirm against its module).
    test_size : fraction of the data held out for evaluation.
    train_transform, val_transform : optional transforms handed through to
        ``RetailDataset`` for the respective split.
    batch_size : number of images preprocessed per ``preprocess_image`` call.

    Returns
    -------
    tuple(RetailDataset, RetailDataset)
        The train and test datasets, pinned to ``model.device``.
    """
    logger.info("Preparing dataset")
    # Split into train and test; when the split is impossible (e.g. too few
    # samples), fall back to using all data for both splits instead of failing.
    try:
        images_train, images_test, labels_train, labels_test = train_test_split(
            images, labels, test_size=test_size
        )
    except ValueError:
        logger.warning(
            "Could not split dataset. Using all data for training and testing"
        )
        images_train, labels_train = images, labels
        images_test, labels_test = images, labels
    # Preprocess both splits with the model feature extractor (shared helper
    # replaces the two previously-duplicated loops).
    images_train_prep = _preprocess_in_batches(
        images_train, model, batch_size, "Preprocessing training images"
    )
    images_test_prep = _preprocess_in_batches(
        images_test, model, batch_size, "Preprocessing test images"
    )
    # Wrap the preprocessed pixel values in BatchFeature containers.
    train_batch_features = BatchFeature(data={"pixel_values": images_train_prep})
    test_batch_features = BatchFeature(data={"pixel_values": images_test_prep})
    # Create the datasets with proper device
    train_dataset = RetailDataset(
        train_batch_features, labels_train, train_transform, device=model.device
    )
    test_dataset = RetailDataset(
        test_batch_features, labels_test, val_transform, device=model.device
    )
    logger.info("Train dataset: %d images", len(labels_train))
    logger.info("Test dataset: %d images", len(labels_test))
    return train_dataset, test_dataset
def re_training(images, labels, _model, save_model_path="new_model", num_epochs=10):
    """Fine-tune *_model* on *images*/*labels* and save the updated weights.

    Encodes the labels with the model's label encoder, builds augmented
    (train) and deterministic (eval) torchvision pipelines, runs a HF
    ``Trainer`` for *num_epochs*, and writes the result to *save_model_path*.
    Also publishes the model as the module-level ``model`` global —
    presumably other module code relies on that; verify before removing.
    """
    global model
    model = _model
    # Map string labels to the integer ids the classifier head expects.
    labels = model.label_encoder.transform(labels)
    normalize = Normalize(
        mean=model.feature_extractor.image_mean,
        std=model.feature_extractor.image_std,
    )

    def train_transforms(batch):
        # Augmented pipeline applied to the training split.
        pipeline = Compose(
            [
                RandomResizedCrop(model.feature_extractor.size),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
            ]
        )
        return pipeline(batch)

    def val_transforms(batch):
        # Deterministic pipeline applied to the evaluation split.
        pipeline = Compose(
            [
                Resize(model.feature_extractor.size),
                CenterCrop(model.feature_extractor.size),
                ToTensor(),
                normalize,
            ]
        )
        return pipeline(batch)

    train_dataset, test_dataset = prepare_dataset(
        images, labels, model, 0.2, train_transforms, val_transforms
    )
    training_args = TrainingArguments(
        output_dir="output",
        overwrite_output_dir=True,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=32,
        gradient_accumulation_steps=1,
        learning_rate=0.000001,
        weight_decay=0.01,
        eval_strategy="steps",
        eval_steps=1000,
        save_steps=3000,
        use_cpu=model.device.type == "cpu",  # Only force CPU if that's our device
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
    trainer.train()
    model.save(save_model_path)