|
import os
|
|
from os.path import abspath as abp
|
|
import torch
|
|
import hydra
|
|
from hydra import initialize, compose
|
|
from models.module import Geolocalizer
|
|
from omegaconf import OmegaConf, open_dict
|
|
from os.path import join
|
|
from hydra.utils import instantiate
|
|
|
|
def load_model_config(path):
|
|
|
|
|
|
path = abp(path)
|
|
rel_path = os.path.relpath(path, start=os.path.split(__file__)[0])
|
|
|
|
with initialize(version_base=None, config_path=rel_path):
|
|
cfg = compose(config_name="config", overrides=[])
|
|
|
|
checkpoint = torch.load(join(path, "last.ckpt"))
|
|
del checkpoint["state_dict"][
|
|
"model.backbone.clip.vision_model.embeddings.position_ids"
|
|
]
|
|
torch.save(checkpoint, join(path, "last2.ckpt"))
|
|
|
|
with open_dict(cfg):
|
|
cfg.checkpoint = join(path, "last2.ckpt")
|
|
|
|
cfg.num_classes = 11399
|
|
cfg.model.network.mid.instance.final_dim = cfg.num_classes * 3
|
|
cfg.model.network.head.final_dim = cfg.num_classes * 3
|
|
cfg.model.network.head.instance.quadtree_path = join(path, "quadtree_10_1000.csv")
|
|
|
|
cfg.dataset.train_dataset.path = ""
|
|
cfg.dataset.val_dataset.path = ""
|
|
cfg.dataset.test_dataset.path = ""
|
|
cfg.logger.save_dir = ""
|
|
cfg.data_dir = ""
|
|
cfg.root_dir = ""
|
|
cfg.mode = "test"
|
|
cfg.model.network.backbone.instance.path = (
|
|
"laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
|
|
)
|
|
return cfg.dataset.test_transform, cfg.model, join(path, "last2.ckpt"), True
|
|
|
|
def load_model(path):
|
|
transform_config, model_config, checkpoint_path, delete = load_model_config(path)
|
|
|
|
transform = instantiate(transform_config)
|
|
model = Geolocalizer.load_from_checkpoint(checkpoint_path, cfg=model_config)
|
|
if delete:
|
|
os.remove(checkpoint_path)
|
|
|
|
return model, transform |