animals_detection / train_model.py
wiklif's picture
added Dockerfile and result folder
3b3bfbb
raw
history blame
3.13 kB
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
# Ustawienia parametrów treningu
img_width, img_height = 224, 224 # Wymiary obrazu wymagane przez model ResNet
batch_size = 32 # Liczba obrazów przetwarzanych na raz podczas treningu
epochs = 10 # Liczba epok treningu
learning_rate = 0.001 # Wskaźnik uczenia się dla optymalizatora
model_path = './result/animal_classifier_resnet.pth' # Ścieżka do zapisu wytrenowanego modelu
# Sprawdzenie, czy jest dostępny GPU i przypisanie urządzenia do zmiennej `device`
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Wyłączenie CuDNN, chyba że twoja karta wspiera bibliotekę CuDNN (NVIDIA CUDA Deep Neural Network library)
# torch.backends.cudnn.enabled = False
# Transformacje danych wejściowych (zmiana rozmiaru, normalizacja)
transform = transforms.Compose([
transforms.Resize((img_width, img_height)), # Zmiana rozmiaru obrazu
transforms.ToTensor(), # Konwersja obrazu do tensoru
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Normalizacja obrazu
])
# Przygotowanie danych treningowych z katalogu `raw-img`
data_dir = 'raw-img' # Ścieżka do katalogu z obrazami treningowymi
train_dataset = datasets.ImageFolder(data_dir, transform=transform) # Wczytanie obrazów i zastosowanie transformacji
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # Tworzenie loadera danych
# Użycie pretrenowanego modelu ResNet18
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features # Liczba wejściowych cech ostatniej warstwy
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes)) # Zastąpienie ostatniej warstwy dopasowanej do liczby klas w danych
# Przeniesienie modelu na GPU, jeśli jest dostępny
model = model.to(device)
# Definicja funkcji kosztu (CrossEntropyLoss) i optymalizatora (Adam)
criterion = nn.CrossEntropyLoss() # Funkcja kosztu
optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Optymalizator
# Trening modelu
for epoch in range(epochs): # Pętla przez epoki
model.train() # Ustawienie modelu w tryb treningowy
running_loss = 0.0 # Zmienna do śledzenia straty
for inputs, labels in train_loader: # Pętla przez batch'e danych
# Przeniesienie danych na GPU, jeśli jest dostępny
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad() # Zerowanie gradientów
outputs = model(inputs) # Przekazanie danych przez model
loss = criterion(outputs, labels) # Obliczenie straty
loss.backward() # Propagacja wsteczna
optimizer.step() # Aktualizacja wag modelu
running_loss += loss.item() # Akumulacja straty
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}") # Wyświetlenie średniej straty na epokę
# Zapisywanie wytrenowanego modelu do pliku
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")