File size: 2,787 Bytes
00b53ff
b7be07b
00b53ff
 
b7be07b
00b53ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7be07b
 
 
 
 
 
 
 
 
 
 
 
00b53ff
 
6d345f4
00b53ff
 
 
 
 
 
 
 
 
 
 
6d345f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import peft
import torch
import whisperx
import torch.nn as nn
from transformers import AutoProcessor, AutoTokenizer
from transformers import CLIPVisionModel, AutoModelForCausalLM


class Projections(nn.Module):
    """Map ``clip_embed``-sized features to ``phi_embed``-sized features.

    ``forward`` applies:

    1. a linear input projection ``clip_embed -> phi_embed`` (``self.output``);
    2. optionally, ``self.norm`` (LayerNorm over the last dimension);
    3. ``num_projection_layers`` residual blocks ``x = MLP(x) + x``, where each
       MLP is ``Linear -> GELU -> Linear`` at width ``phi_embed``.

    Args:
        clip_embed (int): Size of the last dimension of the input.
        phi_embed (int): Size of the last dimension of the output.
        num_projection_layers (int): Number of residual MLP blocks.
        apply_norm (bool): Whether to apply ``self.norm`` after the input
            projection. Defaults to ``False``. The original ``forward``
            called ``self.norm(x)`` but discarded the result, so existing
            checkpoints were trained without normalization. Enabling it would
            change their outputs.
    """

    def __init__(
        self,
        clip_embed,
        phi_embed,
        num_projection_layers=6,
        apply_norm=False,
    ):
        super().__init__()

        self.apply_norm = apply_norm
        # Always registered, even when unused, so that state_dict keys
        # (norm.weight / norm.bias) match existing checkpoints.
        self.norm = nn.LayerNorm(phi_embed)
        self.output = nn.Linear(clip_embed, phi_embed)
        self.projection_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
                for _ in range(num_projection_layers)
            ]
        )

    def forward(self, x):
        """Project ``x`` (last dim ``clip_embed``) to last dim ``phi_embed``."""
        x = self.output(x)
        if self.apply_norm:
            x = self.norm(x)
        for layer in self.projection_layers:
            # Residual connection around each MLP block.
            x = layer(x) + x

        return x

def load_projection_model(path, clip_embed, phi_embed, num_projection_layers=6):
    """Build a ``Projections`` model and load its weights from a checkpoint.

    Args:
        path (str): Path to a checkpoint file containing a ``'state_dict'``
            entry whose projection weights are stored under a leading
            ``'projection.'`` key prefix.
        clip_embed (int): Input feature size passed to ``Projections``.
        phi_embed (int): Output feature size passed to ``Projections``.
        num_projection_layers (int): Number of residual blocks passed to
            ``Projections``. It must match the checkpoint. Defaults to 6,
            the ``Projections`` default.

    Returns:
        torch.nn.Module: The ``Projections`` instance with weights loaded.
    """
    prefix = 'projection.'

    # map_location='cpu' lets CUDA-saved checkpoints load on CPU-only hosts.
    # load_state_dict copies the values into the freshly built module's
    # parameters regardless of where the loaded tensors live.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(path, map_location='cpu')['state_dict']

    # Strip only the leading prefix. str.replace would also rewrite any later
    # occurrence of the substring inside a key.
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    model = Projections(clip_embed, phi_embed, num_projection_layers)
    model.load_state_dict(new_state_dict)

    return model

class Config:
    """Shared constants plus processors/models that are loaded at import time.

    Every model attribute below is created when this class body executes,
    i.e. when the module is imported. That includes loading the CLIP
    processor/vision model, the Phi-2 tokenizer/model, the Whisper model,
    the projection checkpoint and the PEFT adapter from ``models/3200``.
    """

    # Token IDs, presumably in the vocabulary of `tokenizer` below (Phi-2).
    EOS_TOKEN_ID = 50256
    QUESTION_ANSWER_SEPARATOR_ID = 50295  # Special token ID for question-answer separation
    # NOTE(review): presumably the token IDs of a textual separator placed
    # around the image embeddings; confirm with tokenizer.decode(...).
    IMAGE_SEPARATOR_TOKENS = [685, 36259, 14041, 60, 220]

    phi_model_name = "microsoft/phi-2"
    model_name = "openai/clip-vit-base-patch32"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    processor = AutoProcessor.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)

    # 768 -> 2560: projection input/output sizes. These must match the checkpoint.
    projection = load_projection_model("models/MModalGPT-FINETUNE-step=3200-loss=1.01.ckpt", 768, 2560)

    clip_model = CLIPVisionModel.from_pretrained(model_name)
    # whisperx (CTranslate2 backend) rejects float16 on CPU, so fall back to
    # int8 when no GPU is available. On CUDA the original float16 is kept.
    audio_model = whisperx.load_model(
        "small",
        device.type,
        compute_type="float16" if device.type == "cuda" else "int8",
    )

    # No device_map is passed, so the weights are loaded onto the CPU here.
    # NOTE(review): callers presumably move models to `device`; confirm.
    # float16 weights on CPU may be slow or unsupported for some ops.
    text_model = AutoModelForCausalLM.from_pretrained(
        phi_model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        return_dict=True,
        trust_remote_code=True,
    )

    # Wrap the base model with the fine-tuned adapter stored in models/3200.
    peft_model = peft.PeftModel.from_pretrained(text_model, 'models/3200')