import torch

from models import get_model
from networks.base_model import BaseModel
class Validator(BaseModel):
    """Pairs a frozen CLIP ViT-L/14 feature extractor with a
    FeatureTransformer head for validating video-level predictions."""

    def name(self):
        return 'Validator'

    def __init__(self, opt):
        super(Validator, self).__init__(opt)
        self.opt = opt
        self.model = get_model("FeatureTransformer")
        self.clip_model = get_model("CLIP:ViT-L/14")
        # for name, p in self.clip_model.named_parameters():
        #     if name == "fc.weight" or name == "fc.bias":
        #         params.append(p)
        #     else:
        #         p.requires_grad = False
        # del params
        self.model.to(self.device)

    def set_input(self, input):
        # Per-video alternative kept from the original:
        # self.input = torch.cat([self.clip_model.forward(x=video_frames, return_feature=True).unsqueeze(0) for video_frames in input[0]])

        # Run CLIP over all frames in one pass: flatten the batch of videos
        # into a stack of 224x224 RGB frames, extract ViT-L/14 features, then
        # reshape back to (batch, 16 frames, 768 feature dims).
        self.clip_model.to(self.device)
        self.input = self.clip_model.forward(
            x=input[0].to(self.device).view(-1, 3, 224, 224),
            return_feature=True,
        ).view(-1, 16, 768)
        # Move CLIP back to the CPU to free GPU memory between batches.
        self.clip_model.to('cpu')
        self.input = self.input.to(self.device)
        self.label = input[1].to(self.device).float()

    def forward(self):
        # Produce one scalar logit per video, shaped (batch, 1).
        self.output = self.model(self.input)
        self.output = self.output.view(-1).unsqueeze(1)

    def load_state_dict(self, ckpt_path):
        # Load only the transformer head's weights; CLIP keeps its
        # pretrained values.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(state_dict['model'])
        self.model.eval()
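

# --- Minimal usage sketch (not part of the original module) ---
# It assumes an `opt` namespace carrying whatever fields BaseModel expects
# (gpu_ids below is a hypothetical example), that get_model resolves the two
# model names used above, and that a checkpoint exists at the hypothetical
# path shown. The dummy batch mimics one video of 16 RGB frames at 224x224,
# matching the (-1, 3, 224, 224) -> (-1, 16, 768) reshape in set_input.
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(gpu_ids=[])  # hypothetical fields; adjust to BaseModel's needs
    validator = Validator(opt)
    validator.load_state_dict('checkpoints/validator.pth')  # hypothetical path

    frames = torch.randn(1, 16, 3, 224, 224)  # (batch, frames, C, H, W)
    labels = torch.tensor([1.0])
    with torch.no_grad():
        validator.set_input((frames, labels))
        validator.forward()
    print(validator.output.shape)  # expected: torch.Size([1, 1])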