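"""Model setup for MModalGPT: loads the phi language model with its fine-tuned
PEFT adapter, the CLIP vision encoder, the whisperx audio model, and the
trained projection head that maps CLIP embeddings into phi's embedding space.
"""
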
import peft
import torch
import whisperx
import torch.nn as nn
from config import Config
from transformers import CLIPVisionModel, AutoModelForCausalLM


phi_model_name, model_name, device = Config.phi_model_name, Config.model_name, Config.device


def load_projection_model(path, clip_embed, phi_embed):
    """Loads a Projections model instance from a checkpoint and returns it with weights loaded.

    Args:
        path (str): Path to the checkpoint file.
        clip_embed (int): Dimensionality of the CLIP vision embeddings.
        phi_embed (int): Dimensionality of the phi embedding space.

    Returns:
        torch.nn.Module: The loaded Projections model instance.
    """

    state_dict = torch.load(path, map_location="cpu")['state_dict']
    # Checkpoint keys carry a 'projection.' prefix; strip it so they match
    # this module's parameter names.
    new_state_dict = {k.replace('projection.', ''): v for k, v in state_dict.items()}

    model = Projections(clip_embed, phi_embed)
    model.load_state_dict(new_state_dict)

    return model


class Projections(nn.Module):
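    """Projection head that maps CLIP vision embeddings (clip_embed dims) into
    the phi embedding space (phi_embed dims): a linear layer followed by
    LayerNorm and a stack of residual MLP blocks."""
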
    def __init__(
        self,
        clip_embed,
        phi_embed,
        num_projection_layers=6,
    ):
        super().__init__()

        self.norm = nn.LayerNorm(phi_embed)
        self.output = nn.Linear(clip_embed, phi_embed)
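        # Residual MLP blocks that refine the projected embeddings in phi space.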
        self.projection_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
                for _ in range(num_projection_layers)
            ]
        )

    def forward(self, x):
        x = self.output(x)
        x = self.norm(x)
        for layer in self.projection_layers:
            residual = x
            x = layer(x) + residual

        return x


# Load the phi language model in half precision, then attach the fine-tuned
# PEFT adapter from models/29000.
text_model = AutoModelForCausalLM.from_pretrained(phi_model_name,
                                                  torch_dtype=torch.float16,
                                                  #device_map="cuda",
                                                  low_cpu_mem_usage=True,
                                                  return_dict=True,
                                                  trust_remote_code=True)

peft_model = peft.PeftModel.from_pretrained(text_model, 'models/29000')

# Projection head trained to map 768-dim CLIP embeddings into phi's 2560-dim space.
projection = load_projection_model("models/MModalGPT-FINETUNE-step=29000-loss=3.45.ckpt", 768, 2560)

clip_model = CLIPVisionModel.from_pretrained(model_name)
audio_model = whisperx.load_model("small", device.type, compute_type="float16")

projection = projection.to(device)
peft_model = peft_model.to(device)
clip_model = clip_model.to(device)
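

# Minimal usage sketch: project an image's CLIP features into phi's embedding
# space and transcribe an audio clip. The AutoProcessor, "example.jpg", and
# "example.wav" below are illustrative assumptions, not part of this module.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(model_name)
    image = Image.open("example.jpg")
    pixel_values = processor(images=image, return_tensors="pt")["pixel_values"].to(device)

    with torch.no_grad():
        clip_embeds = clip_model(pixel_values=pixel_values).last_hidden_state  # (1, tokens, 768)
        image_embeds = projection(clip_embeds)  # (1, tokens, 2560)
    print(image_embeds.shape)

    # whisperx returns a dict with per-segment transcriptions.
    audio = whisperx.load_audio("example.wav")
    result = audio_model.transcribe(audio, batch_size=16)
    print(result["segments"])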