import peft
import torch
import whisperx
import torch.nn as nn
from config import Config
from transformers import CLIPVisionModel, AutoModelForCausalLM


class Projections(nn.Module):
    """Projects CLIP vision embeddings into the Phi text-embedding space."""

    def __init__(
        self,
        clip_embed,
        phi_embed,
        num_projection_layers=6,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(phi_embed)
        self.output = nn.Linear(clip_embed, phi_embed)
        self.projection_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
                for _ in range(num_projection_layers)
            ]
        )

    def forward(self, x):
        x = self.output(x)
        x = self.norm(x)  # reassign: LayerNorm is not in-place
        for layer in self.projection_layers:
            residual = x
            x = layer(x) + residual
        return x


def load_projection_model(path, clip_embed, phi_embed):
    """Load a Projections model from a checkpoint and return it with weights loaded.

    Args:
        path (str): Path to the checkpoint file.
        clip_embed (int): Dimensionality of the CLIP vision embeddings.
        phi_embed (int): Dimensionality of the Phi embeddings.

    Returns:
        torch.nn.Module: The loaded Projections model instance.
    """
    # map_location avoids CUDA errors when the checkpoint was saved on GPU;
    # the model is moved to the target device afterwards.
    state_dict = torch.load(path, map_location="cpu")['state_dict']
    # Strip the "projection." prefix added by the training wrapper.
    new_state_dict = {k.replace('projection.', ''): v for k, v in state_dict.items()}
    model = Projections(clip_embed, phi_embed)
    model.load_state_dict(new_state_dict)
    return model


phi_model_name, model_name, device = Config.phi_model_name, Config.model_name, Config.device

# Base Phi language model with the fine-tuned LoRA adapter applied on top.
text_model = AutoModelForCausalLM.from_pretrained(phi_model_name,
                                                  torch_dtype=torch.float16,
                                                  # device_map="cuda",
                                                  low_cpu_mem_usage=True,
                                                  return_dict=True,
                                                  trust_remote_code=True)
peft_model = peft.PeftModel.from_pretrained(text_model, 'models/29000')

# Projection head mapping CLIP embeddings (768) to Phi embeddings (2560).
projection = load_projection_model("models/MModalGPT-FINETUNE-step=29000-loss=3.45.ckpt", 768, 2560)
clip_model = CLIPVisionModel.from_pretrained(model_name)
audio_model = whisperx.load_model("small", device.type, compute_type="float16")

projection = projection.to(device)
peft_model = peft_model.to(device)
clip_model = clip_model.to(device)
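

# A minimal usage sketch, not part of the original pipeline: it assumes the
# CLIP checkpoint at `model_name` ships a matching CLIPImageProcessor, and
# "sample.jpg" / "sample.wav" are hypothetical input files. Dropping the CLS
# token before projecting mirrors common multimodal setups and is an
# assumption here, not something the checkpoint guarantees.
if __name__ == "__main__":
    from PIL import Image
    from transformers import CLIPImageProcessor

    processor = CLIPImageProcessor.from_pretrained(model_name)
    image = Image.open("sample.jpg")  # hypothetical input image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    with torch.no_grad():
        # last_hidden_state: (1, num_patches + 1, 768); drop the CLS token.
        clip_out = clip_model(pixel_values=pixel_values).last_hidden_state[:, 1:, :]
        image_embeds = projection(clip_out)  # (1, num_patches, 2560)
    print(image_embeds.shape)

    # Audio side: transcribe a hypothetical file with the whisperx model.
    audio = whisperx.load_audio("sample.wav")
    result = audio_model.transcribe(audio, batch_size=16)
    print(result["segments"])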