Spaces:
Runtime error
Runtime error
from datamodules import CIFAR10QADataModule, ImageDataModule | |
from datamodules.utils import datamodule_factory | |
from models import ImageClassificationNet | |
from models.utils import model_factory | |
from pytorch_lightning.loggers import WandbLogger | |
import argparse | |
import pytorch_lightning as pl | |
def main(args: argparse.Namespace):
    """Evaluate a checkpointed classifier and export it in HuggingFace format.

    Seeds the run, builds the base model and the datamodule from ``args``,
    restores an ``ImageClassificationNet`` from ``args.checkpoint``, runs the
    Lightning test loop with Weights & Biases logging, and finally writes the
    wrapped model and the datamodule's feature extractor to
    ``checkpoints/<base_model>_<dataset>``.

    Args:
        args: Parsed command-line options. This function reads ``seed``,
            ``checkpoint``, ``dataset``, ``base_model``, ``from_pretrained``
            and ``enable_progress_bar``; the factories receive the whole
            namespace.
    """
    # Seed first so everything constructed below is reproducible.
    pl.seed_everything(args.seed)

    # Base network, then the datamodule, both built from the CLI options.
    backbone = model_factory(args, own_config=True)
    datamodule = datamodule_factory(args)

    # Restore the Lightning module from the checkpoint around the base model.
    classifier = ImageClassificationNet.load_from_checkpoint(
        args.checkpoint,
        model=backbone,
        num_train_steps=0,
    )

    # Weights & Biases logger for this evaluation run.
    run_name = f"{args.dataset}_eval_{args.base_model} ({args.from_pretrained})"
    run_logger = WandbLogger(name=run_name, project="Patch-DiffMask")

    # Trainer used only to drive the test loop.
    evaluator = pl.Trainer(
        accelerator="auto",
        logger=run_logger,
        max_epochs=1,
        enable_progress_bar=args.enable_progress_bar,
    )
    evaluator.test(classifier, datamodule)

    # Save the HuggingFace model so it can be reused via --from_pretrained.
    export_dir = f"checkpoints/{args.base_model}_{args.dataset}"
    classifier.model.save_pretrained(export_dir)
    datamodule.feature_extractor.save_pretrained(export_dir)
if __name__ == "__main__":
    # Command-line interface of the evaluation/export script. The help texts
    # describe evaluation (not training), since main() only tests the
    # checkpointed model and saves it in HuggingFace format.
    parser = argparse.ArgumentParser(
        description=(
            "Evaluate a trained image classifier checkpoint on the test set "
            "and export it in HuggingFace format."
        ),
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint of the trained model to evaluate and export.",
    )

    # Trainer
    parser.add_argument(
        "--enable_progress_bar",
        action="store_true",
        help="Whether to show progress bar during evaluation. NOT recommended when logging to files.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducibility.",
    )

    # Base (classification) model
    parser.add_argument(
        "--base_model",
        type=str,
        default="ViT",
        choices=["ViT", "ConvNeXt"],
        help="Base model architecture to evaluate.",
    )
    parser.add_argument(
        "--from_pretrained",
        type=str,
        # default="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
        help="The name of the pretrained HF model to fine-tune from.",
    )

    # Datamodule: each datamodule class registers its own options on the
    # shared parser before the dataset selector is added.
    ImageDataModule.add_model_specific_args(parser)
    CIFAR10QADataModule.add_model_specific_args(parser)
    parser.add_argument(
        "--dataset",
        type=str,
        default="toy",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        help="The dataset to use.",
    )

    args = parser.parse_args()
    main(args)