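# Hugging Face Hub file-page metadata (not part of the program):
# uploaded by Sijuade in commit b7be07b ("Upload 6 files", verified);
# Hub malware scan: no virus found; file size 2.44 kB.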
import peft
import torch
import whisperx
import torch.nn as nn
from config import Config
from transformers import CLIPVisionModel, AutoModelForCausalLM
phi_model_name, model_name, device = Config.phi_model_name, Config.model_name, Config.device
# NOTE(review): ``device.type`` is read below, so Config.device is presumably a
# torch.device — confirm in config.py.

# Artifacts and embedding sizes used to build the multimodal pipeline.
PEFT_ADAPTER_PATH = "models/29000"
PROJECTION_CKPT_PATH = "models/MModalGPT-FINETUNE-step=29000-loss=3.45.ckpt"
CLIP_EMBED_DIM = 768
PHI_EMBED_DIM = 2560


class Projections(nn.Module):
    """Map CLIP vision embeddings into the Phi text-embedding space.

    A linear layer maps the last dimension from ``clip_embed`` to
    ``phi_embed``. It is followed by ``num_projection_layers`` residual MLP
    blocks, each ``Linear -> GELU -> Linear`` added back onto its input.

    Args:
        clip_embed (int): Size of the last dimension of the input.
        phi_embed (int): Size of the last dimension of the output.
        num_projection_layers (int): Number of residual MLP blocks.
        apply_norm (bool): If True, apply ``self.norm`` (a LayerNorm over
            ``phi_embed``) right after the input projection. Defaults to
            False, which keeps the numerical behaviour the released checkpoint
            is used with. NOTE(review): that checkpoint was presumably trained
            without the norm being applied — confirm before enabling.
    """

    def __init__(
        self,
        clip_embed,
        phi_embed,
        num_projection_layers=6,
        apply_norm=False,
    ):
        super().__init__()
        # Registered even when unused, so state dicts keep their ``norm.*``
        # keys and checkpoints still load with strict key matching.
        self.norm = nn.LayerNorm(phi_embed)
        self.output = nn.Linear(clip_embed, phi_embed)
        self.projection_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
                for _ in range(num_projection_layers)
            ]
        )
        self.apply_norm = apply_norm

    def forward(self, x):
        """Project ``x`` (last dim ``clip_embed``) to last dim ``phi_embed``."""
        x = self.output(x)
        if self.apply_norm:
            x = self.norm(x)
        for layer in self.projection_layers:
            # Residual connection around each MLP block.
            x = layer(x) + x
        return x


def load_projection_model(path, clip_embed, phi_embed, map_location="cpu"):
    """Load a ``Projections`` model from a training checkpoint.

    Args:
        path (str): Path to the checkpoint file. It must deserialize to a dict
            with a ``'state_dict'`` entry.
        clip_embed (int): Input embedding size the projection was built with.
        phi_embed (int): Output embedding size the projection was built with.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved from GPU also loads on CPU-only hosts. The model
            is created on CPU and ``load_state_dict`` copies values into its
            existing parameters, so the resulting weights do not depend on
            this choice.

    Returns:
        torch.nn.Module: The ``Projections`` instance with weights loaded.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict by default) if the
            stripped keys do not match the ``Projections`` parameters exactly.
    """
    # torch.load unpickles the file; only point it at trusted checkpoints.
    state_dict = torch.load(path, map_location=map_location)["state_dict"]
    # Drop the "projection." key prefix (presumably the attribute name the
    # projection had inside the training module — confirm).
    new_state_dict = {k.replace("projection.", ""): v for k, v in state_dict.items()}
    model = Projections(clip_embed, phi_embed)
    model.load_state_dict(new_state_dict)
    return model


# Model loading runs at import time and must come after the definitions above;
# calling load_projection_model earlier raises NameError.
text_model = AutoModelForCausalLM.from_pretrained(
    phi_model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    return_dict=True,
    trust_remote_code=True,
)
peft_model = peft.PeftModel.from_pretrained(text_model, PEFT_ADAPTER_PATH)
projection = load_projection_model(PROJECTION_CKPT_PATH, CLIP_EMBED_DIM, PHI_EMBED_DIM)
clip_model = CLIPVisionModel.from_pretrained(model_name)
# whisperx runs on CTranslate2, whose CPU backend rejects float16; int8 is the
# usual CPU choice. On CUDA the original float16 is kept.
audio_compute_type = "float16" if device.type == "cuda" else "int8"
audio_model = whisperx.load_model("small", device.type, compute_type=audio_compute_type)

projection = projection.to(device).eval()
peft_model = peft_model.to(device)
clip_model = clip_model.to(device)