File size: 2,796 Bytes
413d4d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from typing import List
from torchvision.models import vit_b_16 
import torchvision.transforms as transforms

# Generic metric settings. Neither value is referenced in the visible part of
# this file; presumably they are imported by sibling metric modules -- TODO confirm.
ROUND_DIGIT = 3  # presumably the number of decimals used when rounding scores
NUM_ASPECT = 5  # presumably the number of evaluated aspects

# Lower bounds (inclusive) of the averaged adjacent-frame cosine similarity
# for the quantized scores 4, 3 and 2; anything below DINO_POINT_LOW maps to 1.
DINO_POINT_HIGH = 0.97
DINO_POINT_MID = 0.9
DINO_POINT_LOW = 0.8


class MetricDINO_sim():
    """Temporal-consistency metric based on feature similarity of adjacent frames.

    Every frame is passed through a ViT-B/16 network; the cosine similarity of
    the outputs of each adjacent frame pair is averaged over the video and then
    mapped onto a 1-4 scale using the ``DINO_POINT_*`` thresholds.
    """

    def __init__(self, device="cuda") -> None:
        """
        Initialize a class MetricDINO_sim with the specified device for testing temporal consistency of a given video.

        Args:
            device (str, optional): The device on which the model will run. Defaults to "cuda".
        """
        self.device = device
        # NOTE(review): this loads torchvision's `vit_b_16` checkpoint, which
        # appears to be the supervised ImageNet model (its forward pass returns
        # the classification head output), not self-supervised DINO weights --
        # confirm this is the intended backbone for a "DINO" metric.
        self.model = vit_b_16(pretrained=True)
        self.model.to(self.device).eval()
        # Resize the shorter side to 256, center-crop to 224x224, convert to a
        # tensor and normalize with the commonly used ImageNet mean/std.
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _frame_features(self, frame):
        """Return the flattened model output for a single frame."""
        frame_tensor = self.preprocess(frame).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.model(frame_tensor).flatten()

    @staticmethod
    def _quantize(score):
        """Map an averaged similarity onto the 1-4 scale (NaN falls through to 1)."""
        if score >= DINO_POINT_HIGH:
            return 4
        if score >= DINO_POINT_MID:
            return 3
        if score >= DINO_POINT_LOW:
            return 2
        return 1

    def evaluate(self, frame_list: List[Image.Image]):
        """
        Calculate the cosine similarity between the model features of adjacent frames of a given video
        to test temporal consistency, then quantize the original output based on some predefined thresholds.

        Args:
            frame_list: frames of the video used in calculation, in temporal order.
                At least two frames are required.

        Returns:
            dino_frame_score: float, the cosine similarity between the features of each adjacent
                pair of frames, averaged among all the pairs.
            quantized_ans: int, the quantized value (1-4) of the above averaged score based on
                the pre-defined thresholds.

        Raises:
            ValueError: if fewer than two frames are given, since no adjacent pair exists.
        """
        if len(frame_list) < 2:
            # Averaging an empty list of similarities would yield NaN (plus a
            # NumPy RuntimeWarning) and silently quantize to the lowest score.
            raise ValueError(
                "MetricDINO_sim.evaluate needs at least two frames, "
                f"got {len(frame_list)}"
            )

        # Each frame is encoded exactly once; the previous frame's features are
        # carried over instead of being recomputed for the next pair.
        prev_feat = self._frame_features(frame_list[0])
        frame_sim_list = []
        for frame in frame_list[1:]:
            cur_feat = self._frame_features(frame)
            frame_sim_list.append(F.cosine_similarity(prev_feat, cur_feat, dim=0).item())
            prev_feat = cur_feat

        # Convert NumPy's float64 scalar to the plain float promised above.
        dino_frame_score = float(np.mean(frame_sim_list))
        quantized_ans = self._quantize(dino_frame_score)
        return dino_frame_score, quantized_ans