# custom_yolo_v3/custom_library/lightning_model.py
import torch
import torch.optim as optim
import lightning.pytorch as pl
from torch.optim.lr_scheduler import OneCycleLR

from .model import YOLOv3
from .loss import YoloLoss
from .utils import get_loaders, load_checkpoint
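

# YOLOv3 wrapped as a LightningModule: the backbone, the three-scale YOLO
# loss, dataloaders built from the train/test CSVs under config.DATASET,
# and a one-cycle learning-rate schedule all live in this one class.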
class YOLOv3Lightning(pl.LightningModule):
    def __init__(self, config, lr_value=0):
        super().__init__()
        self.automatic_optimization = True
        self.config = config
        self.model = YOLOv3(num_classes=self.config.NUM_CLASSES)
        self.loss_fn = YoloLoss()
        # lr_value == 0 means "use the default learning rate from config".
        self.learning_rate = self.config.LEARNING_RATE if lr_value == 0 else lr_value
        # Pre-scale the anchors to each prediction grid once, registered as a
        # buffer so it follows the module across devices and already exists
        # when the validation sanity check runs before training starts.
        self.register_buffer(
            "scaled_anchors",
            torch.tensor(self.config.ANCHORS)
            * torch.tensor(self.config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2),
        )

    def forward(self, x):
        return self.model(x)
    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,  # honours the lr_value override from __init__
            weight_decay=self.config.WEIGHT_DECAY,
        )
        # The one-cycle schedule spans the first 2/5 of NUM_EPOCHS; stepping a
        # OneCycleLR past its total_steps raises an error, so the Trainer is
        # expected to run for at most EPOCHS epochs.
        EPOCHS = self.config.NUM_EPOCHS * 2 // 5
        scheduler = OneCycleLR(
            optimizer,
            max_lr=1e-3,
            steps_per_epoch=len(self.train_dataloader()),
            epochs=EPOCHS,
            pct_start=5 / EPOCHS,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy='linear',
        )
        # Step the scheduler once per optimizer step rather than once per epoch.
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]
    def train_dataloader(self):
        # get_loaders returns (train, test, val) loaders; keep only train here.
        train_loader, _, _ = get_loaders(
            train_csv_path=self.config.DATASET + "/train.csv",
            test_csv_path=self.config.DATASET + "/test.csv",
        )
        return train_loader
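
    # YOLOv3 predicts at three scales; for the usual 416x416 input the grids
    # are 13x13, 26x26 and 52x52 (config.S), with 3 anchors per cell. The
    # train/val/test steps below sum YoloLoss over all three scales against
    # the matching target tensor and row of scaled_anchors.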
    def training_step(self, batch, batch_idx):
        x, y = batch
        # Lightning already moves the batch to the right device; the explicit
        # .to(self.device) calls are kept as a harmless safeguard.
        y0, y1, y2 = (
            y[0].to(self.device),
            y[1].to(self.device),
            y[2].to(self.device),
        )
        out = self(x)
        loss = (
            self.loss_fn(out[0], y0, self.scaled_anchors[0])
            + self.loss_fn(out[1], y1, self.scaled_anchors[1])
            + self.loss_fn(out[2], y2, self.scaled_anchors[2])
        )
        self.log('train_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss
    def val_dataloader(self):
        _, _, val_loader = get_loaders(
            train_csv_path=self.config.DATASET + "/train.csv",
            test_csv_path=self.config.DATASET + "/test.csv",
        )
        return val_loader
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y0, y1, y2 = (
            y[0].to(self.device),
            y[1].to(self.device),
            y[2].to(self.device),
        )
        out = self(x)
        loss = (
            self.loss_fn(out[0], y0, self.scaled_anchors[0])
            + self.loss_fn(out[1], y1, self.scaled_anchors[1])
            + self.loss_fn(out[2], y2, self.scaled_anchors[2])
        )
        self.log('val_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
    def test_dataloader(self):
        _, test_loader, _ = get_loaders(
            train_csv_path=self.config.DATASET + "/train.csv",
            test_csv_path=self.config.DATASET + "/test.csv",
        )
        return test_loader
    def test_step(self, batch, batch_idx):
        x, y = batch
        y0, y1, y2 = (
            y[0].to(self.device),
            y[1].to(self.device),
            y[2].to(self.device),
        )
        out = self(x)
        loss = (
            self.loss_fn(out[0], y0, self.scaled_anchors[0])
            + self.loss_fn(out[1], y1, self.scaled_anchors[1])
            + self.loss_fn(out[2], y2, self.scaled_anchors[2])
        )
        self.log('test_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
    def on_train_start(self):
        # Optionally resume model weights and optimizer state from a checkpoint.
        if self.config.LOAD_MODEL:
            load_checkpoint(self.config.CHECKPOINT_FILE, self.model, self.optimizers(), self.config.LEARNING_RATE)
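

# A minimal, hypothetical driver script for this module (a sketch under
# assumptions, not part of the repository): it presumes the package layout
# above, and caps max_epochs at the scheduler's span as noted in
# configure_optimizers. Adjust the imports to your own layout.
#
#   import lightning.pytorch as pl
#   from custom_library.lightning_model import YOLOv3Lightning
#   from custom_library import config
#
#   model = YOLOv3Lightning(config)
#   trainer = pl.Trainer(max_epochs=config.NUM_EPOCHS * 2 // 5, accelerator="auto")
#   trainer.fit(model)
#   trainer.test(model)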