Sijuade commited on
Commit
00b53ff
1 Parent(s): 137a199

Update config.py

Browse files
Files changed (1) hide show
  1. config.py +69 -1
config.py CHANGED
@@ -1,5 +1,59 @@
 
1
  import torch
 
 
2
  from transformers import AutoProcessor, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  class Config:
5
 
@@ -12,4 +66,18 @@ class Config:
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
  processor = AutoProcessor.from_pretrained(model_name)
15
- tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import peft
2
  import torch
3
+ import whisperx
4
+ import torch.nn as nn
5
  from transformers import AutoProcessor, AutoTokenizer
6
+ from transformers import CLIPVisionModel, AutoModelForCausalLM
7
+
8
+
9
+ class Projections(nn.Module):
10
+ def __init__(
11
+ self,
12
+ clip_embed,
13
+ phi_embed,
14
+ num_projection_layers=6,
15
+ ):
16
+ super().__init__()
17
+
18
+ self.norm = nn.LayerNorm(phi_embed)
19
+ self.output = nn.Linear(clip_embed, phi_embed)
20
+ self.projection_layers = nn.ModuleList(
21
+ [
22
+ nn.Sequential(
23
+ nn.Linear(phi_embed, phi_embed),
24
+ nn.GELU(),
25
+ nn.Linear(phi_embed, phi_embed),
26
+ )
27
+ for _ in range(num_projection_layers)
28
+ ]
29
+ )
30
+
31
+ def forward(self, x):
32
+ x = self.output(x)
33
+ self.norm(x)
34
+ for layer in self.projection_layers:
35
+ residual = x
36
+ x = layer(x) + residual
37
+
38
+ return x
39
+
40
+ def load_projection_model(path, clip_embed, phi_embed):
41
+ """Loads a Projections model instance from a checkpoint and returns it with weights loaded.
42
+
43
+ Args:
44
+ path (str): Path to the checkpoint file.
45
+
46
+ Returns:
47
+ torch.nn.Module: The loaded Projections model instance.
48
+ """
49
+
50
+ state_dict = torch.load(path)['state_dict']
51
+ new_state_dict = {k.replace('projection.', ''): v for k, v in state_dict.items()}
52
+
53
+ model = Projections(clip_embed, phi_embed)
54
+ model.load_state_dict(new_state_dict)
55
+
56
+ return model
57
 
58
  class Config:
59
 
 
66
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67
 
68
  processor = AutoProcessor.from_pretrained(model_name)
69
+ tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
70
+
71
+ projection = load_projection_model("models/MModalGPT-FINETUNE-step=29000-loss=3.45.ckpt", 768, 2560)
72
+
73
+ clip_model = CLIPVisionModel.from_pretrained(model_name)
74
+ audio_model = whisperx.load_model("small", device.type, compute_type="float16")
75
+
76
+ text_model = AutoModelForCausalLM.from_pretrained(phi_model_name,
77
+ torch_dtype=torch.float16,
78
+ #device_map="cuda",
79
+ low_cpu_mem_usage=True,
80
+ return_dict=True,
81
+ trust_remote_code=True)
82
+
83
+ peft_model = peft.PeftModel.from_pretrained(text_model, 'models/29000')