Spaces:
Sleeping
Sleeping
import torch | |
import torch.optim as optim | |
import lightning.pytorch as pl | |
from tqdm import tqdm | |
from .model import YOLOv3 | |
from .loss import YoloLoss | |
from .utils import get_loaders, load_checkpoint, check_class_accuracy, intersection_over_union | |
from . import config | |
from torch.optim.lr_scheduler import OneCycleLR | |
class YOLOv3Lightning(pl.LightningModule): | |
def __init__(self, config, lr_value=0): | |
super().__init__() | |
self.automatic_optimization =True | |
self.config = config | |
self.model = YOLOv3(num_classes=self.config.NUM_CLASSES) | |
self.loss_fn = YoloLoss() | |
if lr_value == 0: | |
self.learning_rate = self.config.LEARNING_RATE | |
else: | |
self.learning_rate = lr_value | |
def forward(self, x): | |
return self.model(x) | |
def configure_optimizers(self): | |
optimizer = optim.Adam(self.model.parameters(), lr=self.config.LEARNING_RATE, weight_decay=self.config.WEIGHT_DECAY) | |
EPOCHS = self.config.NUM_EPOCHS * 2 // 5 | |
scheduler = OneCycleLR(optimizer, max_lr=1E-3, steps_per_epoch=len(self.train_dataloader()), epochs=EPOCHS, pct_start=5/EPOCHS, div_factor=100, three_phase=False, final_div_factor=100, anneal_strategy='linear') | |
return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}] | |
def train_dataloader(self): | |
train_loader, _, _ = get_loaders( | |
train_csv_path=self.config.DATASET + "/train.csv", | |
test_csv_path=self.config.DATASET + "/test.csv", | |
) | |
return train_loader | |
def training_step(self, batch, batch_idx): | |
x, y = batch | |
y0, y1, y2 = (y[0].to(self.device),y[1].to(self.device),y[2].to(self.device)) | |
out = self(x) | |
loss = (self.loss_fn(out[0], y0, self.scaled_anchors[0]) | |
+ self.loss_fn(out[1], y1, self.scaled_anchors[1]) | |
+ self.loss_fn(out[2], y2, self.scaled_anchors[2])) | |
self.log('train_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) | |
return loss | |
def val_dataloader(self): | |
_, _, val_loader = get_loaders( | |
train_csv_path=self.config.DATASET + "/train.csv", | |
test_csv_path=self.config.DATASET + "/test.csv", | |
) | |
return val_loader | |
def validation_step(self, batch, batch_idx): | |
x, y = batch | |
y0, y1, y2 = ( | |
y[0].to(self.device), | |
y[1].to(self.device), | |
y[2].to(self.device), | |
) | |
out = self(x) | |
loss = ( | |
self.loss_fn(out[0], y0, self.scaled_anchors[0]) | |
+ self.loss_fn(out[1], y1, self.scaled_anchors[1]) | |
+ self.loss_fn(out[2], y2, self.scaled_anchors[2]) | |
) | |
self.log('val_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) | |
def test_dataloader(self): | |
_, test_loader, _ = get_loaders( | |
train_csv_path=self.config.DATASET + "/train.csv", | |
test_csv_path=self.config.DATASET + "/test.csv", | |
) | |
return test_loader | |
def test_step(self, batch, batch_idx): | |
x, y = batch | |
y0, y1, y2 = ( | |
y[0].to(self.device), | |
y[1].to(self.device), | |
y[2].to(self.device), | |
) | |
out = self(x) | |
loss = ( | |
self.loss_fn(out[0], y0, self.scaled_anchors[0]) | |
+ self.loss_fn(out[1], y1, self.scaled_anchors[1]) | |
+ self.loss_fn(out[2], y2, self.scaled_anchors[2]) | |
) | |
self.log('test_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) | |
def on_train_start(self): | |
if self.config.LOAD_MODEL: | |
load_checkpoint(self.config.CHECKPOINT_FILE, self.model, self.optimizers(), self.config.LEARNING_RATE) | |
self.scaled_anchors = ( | |
torch.tensor(self.config.ANCHORS) | |
* torch.tensor(self.config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2) | |
).to(self.device) | |