# CHM-Corr / ExtractEmbedding.py
# Author: taesiri (commit d526dbf: "added CHM classification")
import time
import os
import torch
import numpy as np
import torchvision
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from tqdm import tqdm
import pickle
import argparse
from PIL import Image
def concat(x):
    """Concatenate a sequence of arrays along the first axis (axis 0)."""
    return np.concatenate(x, axis=0)


def to_np(x):
    """Return tensor *x* as a NumPy array on the CPU, detached from autograd.

    ``detach()`` is the supported way to drop the autograd link; the legacy
    ``.data`` attribute bypasses autograd's in-place version tracking. For a
    tensor already on the CPU the result shares memory with *x*.
    """
    return x.detach().to("cpu").numpy()
class Wrapper(torch.nn.Module):
    """Return the output of a model's ``avgpool`` layer instead of its logits.

    A forward hook registered on ``model.avgpool`` records that layer's output
    every time the wrapped model runs; :meth:`forward` runs the model and
    returns the recorded tensor.
    """

    def __init__(self, model):
        """Wrap *model*, which must expose an ``avgpool`` submodule.

        Args:
            model: A ``torch.nn.Module`` with an ``avgpool`` attribute
                (for example a torchvision ResNet).
        """
        super().__init__()
        self.model = model
        # Latest squeezed avgpool output, written by the hook below.
        self.avgpool_output = None
        # Not used inside this class; kept for callers that may set/read them.
        self.query = None
        self.cossim_value = {}

        def fw_hook(_module, _inputs, output):
            # squeeze() removes *every* size-1 dim. Assuming avgpool yields
            # (N, C, 1, 1) as in torchvision ResNets, this gives (N, C) for
            # N > 1, but (C,) for a single-image batch. Callers rely on this.
            self.avgpool_output = output.squeeze()

        # Keep the handle so the hook can be detached via hook_handle.remove().
        self.hook_handle = self.model.avgpool.register_forward_hook(fw_hook)

    def forward(self, input):
        """Run the wrapped model on *input* and return the captured avgpool output.

        Returns ``None`` if the model's forward pass did not invoke ``avgpool``
        (or the hook has been removed).
        """
        # Clear first so an output left over from a previous call is never
        # mistaken for this call's result.
        self.avgpool_output = None
        self.model(input)
        return self.avgpool_output

    def __repr__(self):
        return type(self).__name__
def QueryToEmbedding(query_path):
    """Embed a single image with an ImageNet-pretrained ResNet-50.

    The image is resized to 256, center-cropped to 224x224, converted to a
    tensor and normalized with the ImageNet mean/std, then passed through
    ``torchvision.models.resnet50``. The embedding is the network's
    ``avgpool`` output as captured by ``Wrapper``.

    Args:
        query_path: Anything accepted by ``PIL.Image.open`` (path or file
            object).

    Returns:
        np.ndarray: The embedding wrapped in one extra leading axis
        (``np.asarray([embedding])``); presumably shape (1, 2048) for
        ResNet-50 — confirm against the installed torchvision.
    """
    dataset_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    # NOTE(review): the model is rebuilt and its weights reloaded on every
    # call. `pretrained=True` is deprecated in newer torchvision in favour of
    # `weights=...`; kept here for compatibility with older versions.
    model = torchvision.models.resnet50(pretrained=True)
    model.eval()
    myw = Wrapper(model)

    # Normalize above expects exactly 3 channels, so grayscale, palette or
    # RGBA images must be converted first. convert() reads the pixel data,
    # so the file can be closed as soon as the `with` block exits.
    with Image.open(query_path) as query_pil:
        query_rgb = query_pil.convert("RGB")
    query_pt = dataset_transform(query_rgb)

    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension the model expects.
        embedding = to_np(myw(query_pt.unsqueeze(0)))
    return np.asarray([embedding])