import peft
import torch
import whisperx
import torch.nn as nn
from transformers import AutoProcessor, AutoTokenizer
from transformers import CLIPVisionModel, AutoModelForCausalLM


class Projections(nn.Module):
    """Map ``clip_embed``-sized features to ``phi_embed``-sized features.

    A single linear layer changes the feature size, then a stack of
    residual MLP blocks (Linear -> GELU -> Linear) refines the result.

    Args:
        clip_embed (int): Size of the last dimension of the input.
        phi_embed (int): Size of the last dimension of the output.
        num_projection_layers (int): Number of residual MLP blocks.
    """

    def __init__(
        self,
        clip_embed,
        phi_embed,
        num_projection_layers=6,
    ):
        super().__init__()

        self.norm = nn.LayerNorm(phi_embed)
        self.output = nn.Linear(clip_embed, phi_embed)
        self.projection_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
                for _ in range(num_projection_layers)
            ]
        )

    def forward(self, x):
        """Project ``x`` and pass it through every residual block."""
        x = self.output(x)
        # NOTE(review): the LayerNorm result is discarded, so normalization
        # has no effect on ``x``. Deliberately left unchanged: assigning it
        # would alter the outputs of existing checkpoints (presumably they
        # were trained with this behavior - confirm before changing).
        self.norm(x)
        for layer in self.projection_layers:
            residual = x
            x = layer(x) + residual

        return x


def load_projection_model(path, clip_embed, phi_embed):
    """Loads a Projections model instance from a checkpoint and returns it with weights loaded.

    The checkpoint must be a mapping with a ``'state_dict'`` entry. Every
    occurrence of ``'projection.'`` is stripped from the parameter names so
    they match the attribute names of ``Projections``.

    Args:
        path (str): Path to the checkpoint file.
        clip_embed (int): Input feature size passed to ``Projections``.
        phi_embed (int): Output feature size passed to ``Projections``.

    Returns:
        torch.nn.Module: The loaded Projections model instance.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: If the renamed keys do not match the model
            (raised by ``load_state_dict``).
    """
    # Load onto the CPU so that a checkpoint saved from a CUDA device can be
    # read on a machine without a GPU. The model below is built on the CPU
    # and load_state_dict copies values into it, so the result is unchanged
    # on GPU hosts.
    state_dict = torch.load(path, map_location="cpu")['state_dict']
    new_state_dict = {k.replace('projection.', ''): v for k, v in state_dict.items()}

    model = Projections(clip_embed, phi_embed)
    model.load_state_dict(new_state_dict)

    return model


class Config:
    """Shared constants and pre-loaded models.

    Everything here is evaluated once, when the class body executes at
    import time, including the model loads.
    """

    EOS_TOKEN_ID = 50256
    QUESTION_ANSWER_SEPARATOR_ID = 50295  # Special token ID for question-answer separation
    IMAGE_SEPARATOR_TOKENS = [685, 36259, 14041, 60, 220]

    phi_model_name = "microsoft/phi-2"
    model_name = "openai/clip-vit-base-patch32"

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    processor = AutoProcessor.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)

    # 768 -> 2560 projection restored from the fine-tuning checkpoint.
    projection = load_projection_model(
        "models/MModalGPT-FINETUNE-continued-step=10100-loss=1.16.ckpt", 768, 2560
    )
    clip_model = CLIPVisionModel.from_pretrained(model_name)

    # NOTE(review): compute_type="float16" is requested even when ``device``
    # falls back to "cpu" - confirm the audio backend accepts that.
    audio_model = whisperx.load_model("small", device.type, compute_type="float16")

    text_model = AutoModelForCausalLM.from_pretrained(
        phi_model_name,
        torch_dtype=torch.float16,
        # device_map="cuda",
        low_cpu_mem_usage=True,
        return_dict=True,
        trust_remote_code=True,
    )
    # Wrap the base language model with the adapter stored in 'models/10100'.
    peft_model = peft.PeftModel.from_pretrained(text_model, 'models/10100')