import torch
from dataclasses import dataclass, asdict
from torchvision import transforms


@dataclass
class TrainingConfig:
    """Hyperparameters, paths and augmentation pipeline used for training.

    Only the annotated attributes below are dataclass fields: they are the
    constructor parameters and the keys returned by :meth:`asdict`. The
    ``transforms`` pipeline is a plain class attribute shared by all
    instances and is therefore not part of the serialized dict.
    """

    lr: float = 1e-4
    epochs: int = 2
    batch_size: int = 16
    data_dir: str = "data/eyes/train"
    eyes_model_path: str = "models/eyes_model.pt"
    dataset_state: str = 'train'  # train / val / test
    # NOTE(review): presumably the fraction of the data used for training in
    # a train/val split — confirm against the data-loading code.
    train_size: float = 0.8
    # Evaluated once, at import time: 1 if CUDA is available, otherwise 0.
    gpus: int = 1 if torch.cuda.is_available() else 0
    num_workers: int = 8  # for dataloader

    # Unannotated on purpose -> class attribute, not a dataclass field.
    # On the right-hand side ``transforms`` still refers to the torchvision
    # module, because the class attribute is not bound until after the
    # expression is evaluated.
    transforms = transforms.Compose([
        transforms.Grayscale(3),  # grayscale, replicated to 3 output channels
        transforms.ToTensor(),
        # mean=0.5, std=0.5 applied to every channel: (x - 0.5) / 0.5.
        transforms.Normalize(0.5, 0.5),
        transforms.RandomResizedCrop((64, 64), scale=(0.5, 1)),
        transforms.RandomHorizontalFlip(p=0.2),
    ])

    def asdict(self):
        """Return the dataclass fields as a dict (``transforms`` excluded)."""
        # Method bodies do not see class scope, so ``asdict`` here is the
        # module-level ``dataclasses.asdict``, not this method (no recursion).
        return asdict(self)


@dataclass
class InferenceConfig:
    """Model path, preprocessing pipeline and decision threshold for inference.

    ``eyes_model_path`` and ``classification_threshold`` are dataclass
    fields (constructor parameters and :meth:`asdict` keys); ``transforms``
    is a shared class attribute and is not serialized.
    """

    eyes_model_path: str = "models/eyes_model.pt"

    # Unannotated on purpose -> class attribute, not a dataclass field.
    # NOTE(review): the leading ToPILImage suggests inputs arrive as arrays or
    # tensors rather than PIL images — confirm with the caller.
    transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Grayscale(3),  # grayscale, replicated to 3 output channels
        transforms.ToTensor(),
        # mean=0.5, std=0.5 applied to every channel: (x - 0.5) / 0.5.
        transforms.Normalize(0.5, 0.5),
        transforms.Resize((64, 64)),  # same spatial size as the training crop
    ])

    # Annotated so it is a real dataclass field: without the annotation it
    # was a bare class attribute that could not be set through the
    # constructor and was missing from asdict(). The default is unchanged.
    classification_threshold: float = 0.3

    def asdict(self):
        """Return the dataclass fields as a dict (``transforms`` excluded)."""
        # Resolves to the module-level ``dataclasses.asdict`` (no recursion).
        return asdict(self)


@dataclass
class BlurrinessConfig:
    """Settings for the blurriness check."""

    # NOTE(review): the meaning and units of this score threshold are defined
    # by the blur-detection code, which is not visible here — confirm there.
    threshold: int = 250
    data_dir: str = "data/blur/"

    def asdict(self):
        """Return the dataclass fields as a dict."""
        # Resolves to the module-level ``dataclasses.asdict`` (no recursion).
        return asdict(self)