Sijuade committed
Commit
4c6225d
1 Parent(s): 00b53ff

Update networks.py

Files changed (1)
  1. networks.py +5 -73
networks.py CHANGED
@@ -1,79 +1,11 @@
- import peft
- import torch
- import whisperx
- import torch.nn as nn
  from config import Config
- from transformers import CLIPVisionModel, AutoModelForCausalLM

+ device = Config.device

- phi_model_name, model_name, device = Config.phi_model_name, Config.model_name, Config.device
-
- text_model = AutoModelForCausalLM.from_pretrained(phi_model_name,
-                                                   torch_dtype=torch.float16,
-                                                   #device_map="cuda",
-                                                   low_cpu_mem_usage=True,
-                                                   return_dict=True,
-                                                   trust_remote_code=True)
-
- peft_model = peft.PeftModel.from_pretrained(text_model, 'models/29000')
- projection = load_projection_model("models/MModalGPT-FINETUNE-step=29000-loss=3.45.ckpt", 768, 2560)
-
- clip_model = CLIPVisionModel.from_pretrained(model_name)
- audio_model = whisperx.load_model("small", device.type, compute_type="float16")
-
+ audio_model = Config.audio_model
+ text_model, peft_model = Config.text_model, Config.peft_model
+ projection, clip_model = Config.projection, Config.clip_model

  projection = projection.to(device)
  peft_model = peft_model.to(device)
- clip_model = clip_model.to(device)
-
-
- def load_projection_model(path, clip_embed, phi_embed):
-     """Loads a Projections model instance from a checkpoint and returns it with weights loaded.
-
-     Args:
-         path (str): Path to the checkpoint file.
-
-     Returns:
-         torch.nn.Module: The loaded Projections model instance.
-     """
-
-     state_dict = torch.load(path)['state_dict']
-     new_state_dict = {k.replace('projection.', ''): v for k, v in state_dict.items()}
-
-     model = Projections(clip_embed, phi_embed)
-     model.load_state_dict(new_state_dict)
-
-     return model
-
-
- class Projections(nn.Module):
-     def __init__(
-         self,
-         clip_embed,
-         phi_embed,
-         num_projection_layers=6,
-     ):
-         super().__init__()
-
-         self.norm = nn.LayerNorm(phi_embed)
-         self.output = nn.Linear(clip_embed, phi_embed)
-         self.projection_layers = nn.ModuleList(
-             [
-                 nn.Sequential(
-                     nn.Linear(phi_embed, phi_embed),
-                     nn.GELU(),
-                     nn.Linear(phi_embed, phi_embed),
-                 )
-                 for _ in range(num_projection_layers)
-             ]
-         )
-
-     def forward(self, x):
-         x = self.output(x)
-         self.norm(x)
-         for layer in self.projection_layers:
-             residual = x
-             x = layer(x) + residual
-
-         return x
-
+ clip_model = clip_model.to(device)
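
For context, the refactored networks.py expects config.py to expose already-constructed model objects as attributes of Config (device, audio_model, text_model, peft_model, projection, clip_model). That file is not part of this commit, so the following is only a minimal sketch of what it presumably looks like after the change: Projections and load_projection_model are the definitions removed from networks.py above, assumed to have moved into config.py, and the two Hugging Face model identifiers are assumptions; only the 768 and 2560 embedding sizes appear in the diff.

# config.py -- hedged sketch only; this file is not shown in the commit.
# It assumes the loading code removed from networks.py above now lives here,
# with Projections and load_projection_model moved along with it.
import peft
import torch
import torch.nn as nn
import whisperx
from transformers import CLIPVisionModel, AutoModelForCausalLM


class Projections(nn.Module):
    """Maps CLIP image embeddings (768-d) into the Phi text embedding space (2560-d)."""

    def __init__(self, clip_embed, phi_embed, num_projection_layers=6):
        super().__init__()
        self.norm = nn.LayerNorm(phi_embed)
        self.output = nn.Linear(clip_embed, phi_embed)
        self.projection_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
                for _ in range(num_projection_layers)
            ]
        )

    def forward(self, x):
        x = self.output(x)
        self.norm(x)  # kept as in the removed code: the normalized value is computed but not reassigned
        for layer in self.projection_layers:
            residual = x
            x = layer(x) + residual  # residual connection around each MLP block
        return x


def load_projection_model(path, clip_embed, phi_embed):
    """Builds a Projections model and loads weights from a Lightning-style checkpoint."""
    state_dict = torch.load(path)['state_dict']
    new_state_dict = {k.replace('projection.', ''): v for k, v in state_dict.items()}
    model = Projections(clip_embed, phi_embed)
    model.load_state_dict(new_state_dict)
    return model


class Config:
    # The two model identifiers below are assumptions; the real values live in
    # the project's actual config.py, which this commit does not show.
    phi_model_name = "microsoft/phi-2"
    model_name = "openai/clip-vit-base-patch32"  # a CLIP vision tower with 768-d hidden states
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Same construction steps the old networks.py performed inline.
    text_model = AutoModelForCausalLM.from_pretrained(phi_model_name,
                                                      torch_dtype=torch.float16,
                                                      low_cpu_mem_usage=True,
                                                      return_dict=True,
                                                      trust_remote_code=True)
    peft_model = peft.PeftModel.from_pretrained(text_model, 'models/29000')
    projection = load_projection_model("models/MModalGPT-FINETUNE-step=29000-loss=3.45.ckpt", 768, 2560)
    clip_model = CLIPVisionModel.from_pretrained(model_name)
    audio_model = whisperx.load_model("small", device.type, compute_type="float16")

With Config owning model construction, networks.py shrinks to the eleven lines shown in the new version above: it only pulls the shared objects from Config and moves them onto Config.device.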