import torch
from pytorch_lightning import LightningModule
from model import YOLOv3
from dataset import YOLODataset
from loss import YoloLoss
from torch import optim
from torch.utils.data import DataLoader
import config


class YOLOV3_PL(LightningModule):
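    """LightningModule wrapper around YOLOv3.

    Bundles the model, the YoloLoss criterion, the train/validation dataloaders
    and a One Cycle LR schedule so training can be driven by a
    pytorch_lightning Trainer. Hyperparameters default to the values in config.
    """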
    def __init__(self, in_channels=3, num_classes=config.NUM_CLASSES, batch_size=config.BATCH_SIZE,
                 learning_rate=config.LEARNING_RATE, num_epochs=config.NUM_EPOCHS):
        super().__init__()
        self.model = YOLOv3(in_channels, num_classes)
        self.criterion = YoloLoss()
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        # Anchor boxes rescaled for each prediction grid, precomputed in config.
        self.scaled_anchors = config.SCALED_ANCHORS
        self.layers = self.model.layers

    def train_dataloader(self):
        self.train_data = YOLODataset(
            config.DATASET + '/train.csv',
            transform=config.train_transforms,
            img_dir=config.IMG_DIR,
            label_dir=config.LABEL_DIR,
            anchors=config.ANCHORS
        )

        train_dataloader = DataLoader(
            dataset=self.train_data,
            batch_size=self.batch_size,
            num_workers=config.NUM_WORKERS,
            pin_memory=config.PIN_MEMORY,
            shuffle=True
        )

        return train_dataloader

    def val_dataloader(self):
        self.valid_data = YOLODataset(
            config.DATASET + '/test.csv',
            transform=config.test_transforms,
            img_dir=config.IMG_DIR,
            label_dir=config.LABEL_DIR,
            anchors=config.ANCHORS
        )

        return DataLoader(
            dataset=self.valid_data,
            batch_size=self.batch_size,
            num_workers=config.NUM_WORKERS,
            pin_memory=config.PIN_MEMORY,
            shuffle=False
        )

    def test_dataloader(self):
        # The test split reuses the validation dataloader.
        return self.val_dataloader()

    def forward(self, x):
        # YOLOv3 predicts at three scales; for a 416x416 input the grids are
        # 13x13, 26x26 and 52x52, so this returns one tensor per scale.
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self.forward(x)
        loss = self.criterion(out, y, self.scaled_anchors)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self.forward(x)
        loss = self.criterion(out, y, self.scaled_anchors)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        # Test batches may arrive with or without targets; only the images are used.
        if isinstance(batch, (tuple, list)):
            x, _ = batch
        else:
            x = batch
        return self.forward(x)

    def configure_optimizers(self):
        # The lr passed to Adam is only a starting value; OneCycleLR overrides it.
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate / 100, weight_decay=config.WEIGHT_DECAY)
        # One Cycle policy: warm up from max_lr / div_factor to max_lr over the
        # first 20% of steps, then anneal linearly down to
        # max_lr / (div_factor * final_div_factor).
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.learning_rate,
            steps_per_epoch=len(self.train_dataloader()),
            epochs=self.num_epochs,
            pct_start=0.2,
            div_factor=10,
            three_phase=False,
            final_div_factor=10,
            anneal_strategy='linear'
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                # Step the scheduler every batch rather than every epoch.
                "interval": "step",
            }
        }
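

# How this module is typically driven (a sketch, not exercised by the smoke
# test below; accelerator="auto" is an assumption, pick what suits your setup):
#
#   from pytorch_lightning import Trainer
#   model = YOLOV3_PL()
#   trainer = Trainer(max_epochs=config.NUM_EPOCHS, accelerator="auto")
#   trainer.fit(model)

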
def main():
    # Smoke test: build the module, print a summary, and check that the three
    # output heads have the expected (batch, anchors, grid, grid, classes + 5) shapes.
    num_classes = 20
    IMAGE_SIZE = 416
    INPUT_SIZE = IMAGE_SIZE
    model = YOLOV3_PL(num_classes=num_classes)

    from torchinfo import summary
    print(summary(model, input_size=(2, 3, INPUT_SIZE, INPUT_SIZE)))

    inp = torch.randn((2, 3, INPUT_SIZE, INPUT_SIZE))
    out = model(inp)
    assert out[0].shape == (2, 3, IMAGE_SIZE // 32, IMAGE_SIZE // 32, num_classes + 5)
    assert out[1].shape == (2, 3, IMAGE_SIZE // 16, IMAGE_SIZE // 16, num_classes + 5)
    assert out[2].shape == (2, 3, IMAGE_SIZE // 8, IMAGE_SIZE // 8, num_classes + 5)
    print("Success!")


if __name__ == "__main__":
    main()