# Hosting-page residue captured together with this file (not part of the program):
#   RashiAgarwal's picture
#   Update yolov3.py
#   d55a993
#   raw
#   history blame
#   3.96 kB
import torch
from pytorch_lightning import LightningModule
from model import YOLOv3
from dataset import YOLODataset
from loss import YoloLoss
from torch import optim
from torch.utils.data import DataLoader
import config
class YOLOV3_PL(LightningModule):
    """PyTorch Lightning wrapper around the ``YOLOv3`` network.

    Bundles the model, the ``YoloLoss`` criterion, the data loaders built from
    the ``config`` module and the Adam + OneCycleLR optimisation setup.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=config.NUM_CLASSES,
        batch_size=config.BATCH_SIZE,
        learning_rate=config.LEARNING_RATE,
        num_epochs=config.NUM_EPOCHS,
    ):
        """Build the network and store the training hyper-parameters.

        Args:
            in_channels: Forwarded to ``YOLOv3`` as its first argument.
            num_classes: Forwarded to ``YOLOv3`` as its second argument.
            batch_size: Batch size used by every data loader of this module.
            learning_rate: Peak learning rate (``max_lr``) of the OneCycleLR
                schedule created in ``configure_optimizers``.
            num_epochs: Number of epochs the OneCycleLR schedule is sized for.
        """
        super().__init__()
        self.model = YOLOv3(in_channels, num_classes)
        self.criterion = YoloLoss()
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        # NOTE(review): taken from config as-is; presumably expected to live on
        # the same device as the model outputs -- confirm in config / YoloLoss.
        self.scaled_anchors = config.SCALED_ANCHORS
        self.layers = self.model.layers

    @staticmethod
    def _make_dataset(csv_file, transform):
        """Create a ``YOLODataset`` for ``csv_file`` using the shared config paths."""
        return YOLODataset(
            csv_file,
            transform=transform,
            img_dir=config.IMG_DIR,
            label_dir=config.LABEL_DIR,
            anchors=config.ANCHORS,
        )

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` with this module's batch settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=config.NUM_WORKERS,
            pin_memory=config.PIN_MEMORY,
            shuffle=shuffle,
        )

    def _compute_loss(self, batch):
        """Run a forward pass on ``batch`` and return the criterion's loss."""
        x, y = batch
        out = self.forward(x)
        return self.criterion(out, y, self.scaled_anchors)

    def train_dataloader(self):
        """Return a shuffled loader over ``train.csv``.

        The dataset is rebuilt on every call and kept on ``self.train_data``.
        """
        self.train_data = self._make_dataset(
            config.DATASET + '/train.csv', config.train_transforms
        )
        return self._make_loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over ``test.csv``.

        The dataset is rebuilt on every call and kept on ``self.valid_data``.
        """
        self.valid_data = self._make_dataset(
            config.DATASET + '/test.csv', config.test_transforms
        )
        return self._make_loader(self.valid_data, shuffle=False)

    def test_dataloader(self):
        """Return the validation loader; testing reuses the validation data."""
        return self.val_dataloader()

    def forward(self, x):
        """Return the wrapped model's output for input ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log it as ``train_loss``."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and log it as ``val_loss``."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the model output for a test batch.

        ``batch`` may be an ``(input, target)`` tuple/list, in which case the
        second element is ignored, or the input itself.
        """
        if isinstance(batch, (tuple, list)):
            x, _ = batch
        else:
            x = batch
        return self.forward(x)

    def configure_optimizers(self):
        """Create the Adam optimizer and a per-step OneCycleLR schedule.

        Returns:
            A dict with the optimizer and an ``lr_scheduler`` entry whose
            ``interval`` is ``"step"``, i.e. the scheduler advances every
            optimizer step rather than every epoch.
        """
        # NOTE: OneCycleLR starts from max_lr / div_factor, so the ``lr`` given
        # to Adam here is superseded once the scheduler is constructed.
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.learning_rate / 100,
            weight_decay=config.WEIGHT_DECAY,
        )
        # NOTE(review): this builds a fresh dataset + loader only to read its
        # length, and the count ignores anything the Trainer changes (multiple
        # devices, gradient accumulation, batch limits) -- confirm that this
        # matches how the Trainer is configured.
        steps_per_epoch = len(self.train_dataloader())
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.learning_rate,
            steps_per_epoch=steps_per_epoch,
            epochs=self.num_epochs,
            pct_start=0.2,
            div_factor=10,
            three_phase=False,
            final_div_factor=10,
            anneal_strategy='linear',
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                "interval": "step",
            },
        }
def main():
    """Smoke-test the Lightning module: print a summary and check output shapes.

    Builds a 20-class model, prints its ``torchinfo`` summary, pushes a random
    batch of two 416x416 RGB images through it and asserts that the three
    outputs have grids of 1/32, 1/16 and 1/8 of the image size.
    """
    num_classes = 20
    IMAGE_SIZE = 416
    INPUT_SIZE = IMAGE_SIZE
    input_shape = (2, 3, INPUT_SIZE, INPUT_SIZE)

    model = YOLOV3_PL(num_classes=num_classes)

    # Imported lazily so torchinfo is only required when running this check.
    from torchinfo import summary

    print(summary(model, input_size=input_shape))

    out = model(torch.randn(input_shape))

    # Output i is expected on a grid downsampled by the i-th stride.
    for index, stride in enumerate((32, 16, 8)):
        grid = IMAGE_SIZE // stride
        assert out[index].shape == (2, 3, grid, grid, num_classes + 5)

    print("Success!")


if __name__ == "__main__":
    main()