File size: 2,773 Bytes
9b501ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from torch import nn
from transformers import SiglipVisionModel, SiglipVisionConfig

# 384 / 14 ~= 27.43 is not an integer, so the patch grid uses the floor: 27 patches per side,
# i.e. 27 * 27 = 729 position embeddings covering 27 * 14 = 378 pixels per side.

class SiglipEncoder(nn.Module):
    """Thin wrapper around HuggingFace ``SiglipVisionModel``.

    The model is built from a config only (no weights are loaded here), and
    ``forward`` returns just the ``last_hidden_state`` of the vision tower.

    Args:
        vision_config: mapping of keyword arguments forwarded to
            ``SiglipVisionConfig``.
    """

    def __init__(self, vision_config):
        super().__init__()
        self.model = SiglipVisionModel(SiglipVisionConfig(**vision_config))

    def forward(self, images):
        """Return the final hidden states produced by the SigLIP vision model for ``images``."""
        return self.model(images).last_hidden_state


class GLU(nn.Module):
    """Project features to ``args.hidden_size`` and apply a SiLU-gated MLP.

    Pipeline: linear projection -> LayerNorm -> GELU ->
    silu(gate_proj(h)) * dense_h_to_4h(h) -> dense_4h_to_h. All linear layers
    are bias-free.

    Args:
        args: object exposing ``hidden_size`` and ``intermediate_size``.
        in_features: size of the last dimension of the input tensor.
    """

    def __init__(self, args, in_features):
        super().__init__()
        hidden = args.hidden_size
        inner = args.intermediate_size
        # Attribute names become state_dict keys; keep them (and their order) stable
        # so existing checkpoints still load.
        self.linear_proj = nn.Linear(in_features, hidden, bias=False)
        self.norm1 = nn.LayerNorm(hidden)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(hidden, inner, bias=False)
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.dense_4h_to_h = nn.Linear(inner, hidden, bias=False)

    def forward(self, x):
        """Map ``(..., in_features)`` to ``(..., hidden_size)``."""
        projected = self.act1(self.norm1(self.linear_proj(x)))
        gated = self.act2(self.gate_proj(projected))
        return self.dense_4h_to_h(gated * self.dense_h_to_4h(projected))


class Adapter(nn.Module):
    """Downsample a square grid of vision tokens and wrap it with begin/end-of-image embeddings.

    Args:
        eva_hidden_size: channel size of the incoming vision tokens.
        args: object exposing ``hidden_size`` (the output width) plus the
            fields ``GLU`` reads (``intermediate_size``).
    """

    def __init__(self, eva_hidden_size, args):
        super().__init__()
        # Learned marker embeddings placed before / after every image sequence.
        self.boi = nn.Parameter(torch.ones(1, 1, args.hidden_size).float())
        self.eoi = nn.Parameter(torch.ones(1, 1, args.hidden_size).float())
        # 2x2 stride-2 patch merge with no padding: each spatial side becomes
        # floor(side / 2), so an odd grid (e.g. 27) drops its last row/column.
        self.conv = nn.Conv2d(in_channels=eva_hidden_size, out_channels=args.hidden_size, kernel_size=2, stride=2)
        self.linear_proj = GLU(args, args.hidden_size)

    def forward(self, image_emb):
        """Map ``(b, s, c)`` vision tokens to ``(b, (g // 2) ** 2 + 2, hidden_size)``, with ``g = sqrt(s)``.

        Raises:
            ValueError: if ``s`` is not a perfect square.
        """
        b, s, e = image_emb.shape
        # round() rather than int() so float sqrt error cannot shave off a row.
        grid_size = round(s ** 0.5)
        if grid_size * grid_size != s:
            raise ValueError(f"Adapter expects a square token grid, got {s} tokens")
        # (b, s, e) -> (b, e, g, g): channels-first feature map for the conv.
        image_emb = image_emb.view(b, grid_size, grid_size, e).permute(0, 3, 1, 2)
        image_emb = self.conv(image_emb)  # (b, hidden, g // 2, g // 2)
        image_emb = image_emb.flatten(2).transpose(1, 2)  # (b, (g // 2) ** 2, hidden)
        image_emb = self.linear_proj(image_emb)  # (b, (g // 2) ** 2, hidden)
        # expand() broadcasts the markers over the batch without copying; cat() materializes.
        boi = self.boi.expand(b, -1, -1)
        eoi = self.eoi.expand(b, -1, -1)
        return torch.cat([boi, image_emb, eoi], dim=1)


class VisionModel(torch.nn.Module):
    """SigLIP vision encoder followed by the token ``Adapter``.

    Both the incoming image tensor and the adapter output are cast to
    ``config.torch_dtype``.

    Args:
        config: object exposing ``torch_dtype``, a ``vision_config`` mapping
            (with a ``'hidden_size'`` key), and the fields ``Adapter`` reads.
    """

    def __init__(self, config):
        super().__init__()
        vision_cfg = config.vision_config
        self.dtype = config.torch_dtype
        self.vit = SiglipEncoder(vision_cfg)
        self.adapter = Adapter(vision_cfg['hidden_size'], config)

    def forward(self, image):
        """Encode ``image`` and return the adapted token sequence in ``self.dtype``."""
        features = self.vit(image.to(self.dtype))
        projected = self.adapter(features)
        return projected.to(self.dtype)