from pathlib import Path

import torch
from torch import Tensor

from torchmetrics.utilities import rank_zero_info
import clip
from tqdm import tqdm
from PIL import Image

def read_image(imgid):
    """Open an image given either a raw path or an id under data_en/images/."""
    vanilla = Path(imgid)
    fixed = Path(f"data_en/images/{imgid}")
    # exactly one of the two candidate paths must exist
    assert vanilla.exists() != fixed.exists()

    path = vanilla if vanilla.exists() else fixed
    return Image.open(path).convert("RGB")


class MID:
    def __init__(self, device="cuda"):
        self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=device)
        self.device = device

    def batchify(self, targets, batch_size):
        return [targets[i:i + batch_size] for i in range(0, len(targets), batch_size)]

    def __call__(self, mt_list, refs_list, img_list, no_ref=False):
        B = 32
        mt_list, refs_list, img_list = [self.batchify(x, B) for x in [mt_list, refs_list, img_list]]
        assert len(mt_list) == len(refs_list) == len(img_list)
        img_feats, ref_feats, mt_feats = [], [], []
        for mt, refs, imgs in (pbar := tqdm(zip(mt_list, refs_list, img_list), total=len(mt_list))):
            pbar.set_description("MID")
            imgs = [read_image(imgid) for imgid in imgs]
            refs_token = []
            for ref_list in refs:
                refs_token.append([clip.tokenize(ref, truncate=True).to(self.device) for ref in ref_list])

            refs = torch.cat([torch.cat(ref, dim=0) for ref in refs_token], dim=0)
            mts = clip.tokenize(mt, truncate=True).to(self.device)
            imgs = torch.cat([self.clip_preprocess(img).unsqueeze(0) for img in imgs], dim=0).to(self.device)

            # compute_pmi requires double precision, so cast the CLIP features
            with torch.no_grad():
                img_feats.append(self.clip.encode_image(imgs).double())
                mt_feats.append(self.clip.encode_text(mts).double())
                ref_feats.append(self.clip.encode_text(refs).double())

        # MID is a set-level divergence: accumulate features over all batches
        # and compute a single score (a 32-sample batch cannot satisfy the
        # full-rank check in compute_pmi)
        return compute_pmi(torch.cat(img_feats), torch.cat(ref_feats), torch.cat(mt_feats))
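
# A minimal usage sketch (the captions and image ids below are hypothetical;
# image ids are assumed to resolve via read_image, and a meaningful score
# needs at least as many samples as the 512-dim CLIP features):
#
#     mid = MID(device="cuda")
#     mt_list = ["a dog runs on the beach", ...]          # candidate captions
#     refs_list = [["a dog running along a beach"], ...]  # one reference list per image
#     img_list = ["0001.jpg", ...]                        # image ids
#     score = mid(mt_list, refs_list, img_list)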



def log_det(X):
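    # X is a symmetric PSD covariance, so its singular values coincide with
    # its eigenvalues; summing their logs yields log det(X) more stably than
    # computing det(X) directly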
    eigenvalues = X.svd()[1]
    return eigenvalues.log().sum()


def robust_inv(x, eps=0):
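    # add eps * I to regularize an ill-conditioned covariance before inversion
    # (the hint printed in compute_pmi suggests eps around 5e-4 for small
    # sample sizes)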
    Id = torch.eye(x.shape[0]).to(x.device)
    return (x + eps * Id).inverse()


def exp_smd(a, b, reduction=True):
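    # expected squared Mahalanobis distance: tr(a^{-1} b) when b is a
    # covariance (reduction=True), or per-row diag(b a^{-1} b^T) when b holds
    # centered sample vectors (reduction=False)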
    a_inv = robust_inv(a)
    if reduction:
        assert b.shape[0] == b.shape[1]
        return (a_inv @ b).trace()
    else:
        return (b @ a_inv @ b.t()).diag()


def compute_pmi(x: Tensor, y: Tensor, x0: Tensor, limit: int = 30000,
                 reduction: bool = True, full: bool = False) -> Tensor:
    r"""
    A numerically stable version of the MID score.

    Args:
        x (Tensor): features for real samples
        y (Tensor): features for text samples
        x0 (Tensor): features for fake samples
        limit (int): limit the number of samples
        reduction (bool): returns the expectation of PMI if true else sample-wise results
        full (bool): use full samples from real images

    Returns:
        Scalar value of the mutual information divergence between the sets.
    """
    N = x.shape[0]
    excess = N - limit
    if 0 < excess:
        if not full:
            x = x[:-excess]
            y = y[:-excess]
        x0 = x0[:-excess]
    N = x.shape[0]
    M = x0.shape[0]

    assert N >= x.shape[1], "not full rank for matrix inversion!"
    if x.shape[0] < 30000:
        rank_zero_info("if the score underperforms, please consider using "
                       "an epsilon of 5e-4 or similar in robust_inv.")

    z = torch.cat([x, y], dim=-1)
    z0 = torch.cat([x0, y[:x0.shape[0]]], dim=-1)
    x_mean = x.mean(dim=0, keepdim=True)
    y_mean = y.mean(dim=0, keepdim=True)
    z_mean = torch.cat([x_mean, y_mean], dim=-1)
    x0_mean = x0.mean(dim=0, keepdim=True)
    z0_mean = z0.mean(dim=0, keepdim=True)

    X = (x - x_mean).t() @ (x - x_mean) / (N - 1)
    Y = (y - y_mean).t() @ (y - y_mean) / (N - 1)
    Z = (z - z_mean).t() @ (z - z_mean) / (N - 1)
    X0 = (x0 - x_mean).t() @ (x0 - x_mean) / (M - 1)  # use the reference mean
    Z0 = (z0 - z_mean).t() @ (z0 - z_mean) / (M - 1)  # use the reference mean

    alternative_comp = False
    # note: this factorized form can be numerically unstable, so we don't use it
    if alternative_comp:
        def factorized_cov(x, m):
            N = x.shape[0]
            return (x.t() @ x - N * m.t() @ m) / (N - 1)
        X0 = factorized_cov(x0, x_mean)
        Z0 = factorized_cov(z0, z_mean)

    # the covariance estimates must be in double precision
    for cov in (X, Y, Z, X0, Z0):
        assert cov.dtype == torch.float64

    # Expectation of PMI
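    # (under the Gaussian assumption,
    #  I(X; Y) = (log det Cov[X] + log det Cov[Y] - log det Cov[X, Y]) / 2)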
    mi = (log_det(X) + log_det(Y) - log_det(Z)) / 2
    rank_zero_info(f"MI of real images: {mi:.4f}")

    # Squared Mahalanobis Distance terms
    if reduction:
        smd = (exp_smd(X, X0) + exp_smd(Y, Y) - exp_smd(Z, Z0)) / 2
    else:
        smd = (exp_smd(X, x0 - x_mean, False) + exp_smd(Y, y - y_mean, False)
               - exp_smd(Z, z0 - z_mean, False)) / 2
        mi = mi.unsqueeze(0)  # for broadcasting

    return mi + smd
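

# A self-contained smoke test on synthetic Gaussian features (the shapes and
# noise scales are illustrative assumptions; real usage feeds CLIP embeddings
# as in MID.__call__ above):
if __name__ == "__main__":
    torch.manual_seed(0)
    n, d = 2048, 64
    x = torch.randn(n, d, dtype=torch.float64)             # "real" image features
    y = x + 0.1 * torch.randn(n, d, dtype=torch.float64)   # aligned text features
    x0 = x + 0.5 * torch.randn(n, d, dtype=torch.float64)  # "fake" sample features
    print(f"MID (synthetic): {compute_pmi(x, y, x0).item():.4f}")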