RashiAgarwal's picture
Upload train.py
042c19a
raw
history blame
497 Bytes
"""
Main file for training a YOLO model on the Pascal VOC and COCO datasets.
"""
import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
import torch
import warnings
# Ask cuDNN to benchmark candidate convolution algorithms and reuse the
# fastest; this pays off mainly when input sizes stay fixed between batches.
torch.backends.cudnn.benchmark = True
# NOTE(review): this silences every warning process-wide (deprecations
# included); narrow it to specific categories/modules if that is unwanted.
warnings.filterwarnings(action="ignore")
class YOLOTraining(LightningModule):
    """Lightning module that delegates its forward pass to a wrapped model.

    This class defines only construction and ``forward``; it has no training
    step or optimizer configuration of its own.

    Args:
        model: The network to wrap. Presumably a ``torch.nn.Module`` (so that
            it is registered as a submodule) — not checked here.
    """

    def __init__(self, model):
        super(YOLOTraining, self).__init__()
        # Plain attribute assignment; torch.nn.Module.__setattr__ registers it
        # as a child module when it is an nn.Module.
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output unchanged."""
        prediction = self.model(x)
        return prediction
if __name__ == "__main__":
    # Script entry point; so far it only configures the class count, which
    # nothing below consumes yet.
    # NOTE(review): 20 presumably matches the Pascal VOC category count for
    # the dataset named in the module docstring — confirm.
    num_classes = 20