Spaces:
Build error
Build error
import peft | |
import torch | |
import whisperx | |
import torch.nn as nn | |
from config import Config | |
from transformers import CLIPVisionModel, AutoModelForCausalLM | |
phi_model_name, model_name, device = Config.phi_model_name, Config.model_name, Config.device


class Projections(nn.Module):
    """Project CLIP image features into the Phi token-embedding space.

    A single linear layer maps ``clip_embed``-wide features to ``phi_embed``
    width. It is followed by ``num_projection_layers`` residual MLP blocks,
    each computing ``x = (Linear -> GELU -> Linear)(x) + x``.

    Args:
        clip_embed: Width of the incoming CLIP features.
        phi_embed: Width of the Phi embeddings to project into.
        num_projection_layers: Number of residual MLP blocks.
        apply_norm: When True, apply ``self.norm`` (LayerNorm over the last
            dimension) right after the input projection. Defaults to False.
            The previous ``forward`` computed the norm but threw the result
            away, so existing checkpoints were presumably trained without it.
            Enabling it changes outputs.
    """

    def __init__(
        self,
        clip_embed,
        phi_embed,
        num_projection_layers=6,
        *,
        apply_norm=False,
    ):
        super().__init__()
        self.apply_norm = apply_norm
        # Always registered, even when unused, so that checkpoint keys
        # ``norm.weight`` / ``norm.bias`` still load under strict matching.
        self.norm = nn.LayerNorm(phi_embed)
        self.output = nn.Linear(clip_embed, phi_embed)
        self.projection_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
                for _ in range(num_projection_layers)
            ]
        )

    def forward(self, x):
        """Project ``x`` (last dim == ``clip_embed``) to width ``phi_embed``."""
        x = self.output(x)
        if self.apply_norm:
            x = self.norm(x)
        for layer in self.projection_layers:
            x = layer(x) + x
        return x


def load_projection_model(path, clip_embed, phi_embed, map_location="cpu"):
    """Build a Projections model and load its weights from a checkpoint.

    Args:
        path (str): Path to the checkpoint file. Its top-level object must
            hold a ``'state_dict'`` entry whose keys carry a ``'projection.'``
            prefix. The ``.ckpt`` name suggests a PyTorch Lightning
            checkpoint; confirm this against the training code.
        clip_embed (int): Input width passed to ``Projections``.
        phi_embed (int): Output width passed to ``Projections``.
        map_location: Forwarded to ``torch.load``. The default is ``"cpu"``
            so that a checkpoint saved from GPU tensors can be read on a
            CPU-only host. The returned module is built on CPU, and
            ``load_state_dict`` copies into it, so its placement is CPU
            regardless of this value.

    Returns:
        torch.nn.Module: The ``Projections`` instance with weights loaded.
            It sits on CPU; the caller moves it to the target device.
    """
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(path, map_location=map_location)["state_dict"]
    prefix = "projection."
    # Strip the wrapper's attribute prefix only at the start of each key.
    # str.replace would also delete the substring in the middle of a key.
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model = Projections(clip_embed, phi_embed)
    model.load_state_dict(new_state_dict)
    return model


# The loading below must come AFTER the definitions above. Calling
# load_projection_model before it (and Projections) exist raises NameError.
#
# Half precision is used only on CUDA. CTranslate2 (behind whisperx) rejects
# float16 on CPU, and torch float16 matmuls on CPU are poorly supported.
# NOTE(review): assumes Config.device is a torch.device, since .type is read.
_on_cuda = device.type == "cuda"

text_model = AutoModelForCausalLM.from_pretrained(
    phi_model_name,
    torch_dtype=torch.float16 if _on_cuda else torch.float32,
    low_cpu_mem_usage=True,
    return_dict=True,
    trust_remote_code=True,
)
peft_model = peft.PeftModel.from_pretrained(text_model, "models/29000").to(device)

projection = load_projection_model(
    "models/MModalGPT-FINETUNE-step=29000-loss=3.45.ckpt", 768, 2560
).to(device)
# A freshly constructed nn.Module starts in training mode; switch to inference.
projection.eval()

clip_model = CLIPVisionModel.from_pretrained(model_name).to(device)

audio_model = whisperx.load_model(
    "small", device.type, compute_type="float16" if _on_cuda else "int8"
)