FakeVideoDetect / networks /validator.py
ybbwcwaps
AI Video
3cc4a06
raw
history blame
No virus
1.51 kB
import functools
from typing import Mapping
import torch
import torch.nn as nn
from networks.base_model import BaseModel
import sys
from models import get_model
class Validator(BaseModel):
    """Inference-side wrapper around a CLIP feature extractor and a classifier.

    Each batch is first encoded by ``clip_model`` (called with
    ``return_feature=True``); the resulting features are then passed to
    ``model`` (a ``FeatureTransformer``) in :meth:`forward`.

    Attributes set by this class:
        opt: The options object passed to the constructor.
        model: Network returned by ``get_model("FeatureTransformer")``.
        clip_model: Network returned by ``get_model("CLIP:ViT-L/14")``.
        input: Feature tensor produced by :meth:`set_input`.
        label: Float label tensor produced by :meth:`set_input`.
        output: Prediction tensor produced by :meth:`forward`.
    """

    # Sizes fixed by the reshape calls in ``set_input``.
    # NOTE(review): presumably 16 frames per video and 768-dim CLIP features;
    # confirm against the data loader and the CLIP wrapper.
    _IMAGE_SIZE = 224
    _FRAMES_PER_SAMPLE = 16
    _FEATURE_DIM = 768

    def name(self):
        """Return the identifying name of this wrapper."""
        return 'Validator'

    def __init__(self, opt):
        """Build both networks and move the classifier to ``self.device``.

        Args:
            opt: Options object forwarded to ``BaseModel.__init__``.
        """
        super().__init__(opt)
        self.opt = opt
        self.model = get_model("FeatureTransformer")
        # CLIP is used purely as a feature extractor here; it is moved to the
        # device only for the duration of each ``set_input`` call.
        self.clip_model = get_model("CLIP:ViT-L/14")
        self.model.to(self.device)

    def set_input(self, input):
        """Encode a batch with CLIP and store features and labels.

        Args:
            input: Indexable pair. ``input[0]`` is a tensor that can be viewed
                as ``(-1, 3, 224, 224)``; ``input[1]`` is the label tensor.
        """
        frames = input[0].to(self.device).view(
            -1, 3, self._IMAGE_SIZE, self._IMAGE_SIZE
        )

        self.clip_model.to(self.device)
        try:
            # No gradients are needed through the feature extractor; skipping
            # the autograd graph avoids retaining CLIP activations in memory.
            with torch.no_grad():
                features = self.clip_model.forward(x=frames, return_feature=True)
        finally:
            # Always move CLIP back off the device, even if encoding failed.
            self.clip_model.to('cpu')

        # Regroup per-image features into fixed-length sequences per sample.
        self.input = features.view(
            -1, self._FRAMES_PER_SAMPLE, self._FEATURE_DIM
        )
        self.label = input[1].to(self.device).float()

    def forward(self):
        """Run the classifier on ``self.input`` and store ``self.output``.

        The raw model output is flattened and given a trailing axis, so
        ``self.output`` has shape ``(N, 1)``.
        """
        predictions = self.model(self.input)
        self.output = predictions.view(-1).unsqueeze(1)

    def load_state_dict(self, ckpt_path):
        """Load classifier weights from a checkpoint and switch to eval mode.

        Args:
            ckpt_path: Path to a checkpoint whose ``'model'`` entry holds the
                classifier's state dict.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(state_dict['model'])
        self.model.eval()