import inspect import math import os import warnings from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers import PretrainedConfig from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) import numpy as np from transformers import Phi3ForCausalLM import inspect import math import os import warnings from typing import List, Optional, Tuple, Union from tqdm import tqdm, trange import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) 
import argparse
import json
import math
import os
import re
import sys
import time
from collections import defaultdict
from typing import cast, List, Union, Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig  # pylint: disable=C0413


class MAB_POST(nn.Module):
    """Multi-head attention block with post-residual LayerNorm (Set-Transformer style).

    Queries ``Q`` attend over keys/values projected from ``K``. The attention output
    is added back to the raw ``Q`` as a residual and optionally LayerNorm'd (``ln0``).
    It then goes through a ReLU feed-forward residual (``fc_o``) and is optionally
    LayerNorm'd again (``ln1``).

    Args:
        dim_Q: feature size of the query input.
        dim_K: feature size of the key/value input.
        dim_V: width of the projected queries/keys/values and of the output. Because
            the residual adds the raw query (``Q + O``), ``dim_Q`` must match ``dim_V``.
        num_heads: number of attention heads. It should divide ``dim_V`` evenly;
            otherwise the per-head split/concat in ``forward`` fails.
        ln: if True, apply LayerNorm after each residual.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB_POST, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)

    def forward(self, Q, K, pad_mask=None):
        """Attend from ``Q`` over ``K``.

        Args:
            Q: query tensor of shape (batch, n_q, dim_Q).
            K: key/value tensor of shape (batch, n_k, dim_K).
            pad_mask: (batch, n_k) tensor that is 0 at padded key positions.
                ``None`` means every key position is valid.

        Returns:
            Tensor of shape (batch, n_q, dim_V).
        """
        if pad_mask is None:
            # The default used to crash on ``.unsqueeze``; treat it as "no padding".
            pad_mask = torch.ones(K.size(0), K.size(1), device=K.device)

        Q_ = self.fc_q(Q)
        K_, V_ = self.fc_k(K), self.fc_v(K)

        # Fold heads into the batch dimension, head-major:
        # (num_heads * batch, n, dim_V // num_heads).
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q_.split(dim_split, 2), 0)
        K_ = torch.cat(K_.split(dim_split, 2), 0)
        V_ = torch.cat(V_.split(dim_split, 2), 0)

        # Tile the key mask in the same head-major order: (num_heads * batch, n_q, n_k).
        pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)

        # NOTE: scaled by sqrt(dim_V) rather than the per-head width. Changing this
        # would alter the outputs of any already-trained weights.
        score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        score = score.masked_fill(pad_mask == 0, -1e12)
        A = torch.softmax(score, 2)
        # Zero attention on padded keys explicitly. For a row whose keys are all
        # padding, the softmax above is uniform rather than zero.
        A = A * pad_mask

        # Unfold heads back into the feature dimension: (batch, n_q, dim_V).
        O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
        O = Q + O
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O


class PMA(nn.Module):
    """Pooling by Multihead Attention: learned seed vectors attend over an input set.

    Args:
        dim: feature size of the input set ``X``.
        compress_dim: feature size of the learned seeds, used as the query/value width
            of the inner attention block.
        num_heads: attention heads of the inner block.
        num_seeds: number of learned seed (query) vectors.
        ln: whether the inner block applies LayerNorm.
        pma_mode: inner block variant: 'post_normal', 'pre_normal' or 'pre_gptj'.

    Raises:
        ValueError: if ``pma_mode`` is not one of the supported values.
    """

    def __init__(self, dim, compress_dim, num_heads, num_seeds, ln=False, pma_mode=None):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, compress_dim))
        nn.init.xavier_uniform_(self.S)
        # NOTE(review): MAB_PRE_NORMAL / MAB_PRE_GPTJ are not defined in the visible
        # source; confirm they exist elsewhere, or those modes raise NameError.
        if pma_mode == 'post_normal':
            self.mab = MAB_POST(compress_dim, dim, compress_dim, num_heads, ln=ln)
        elif pma_mode == 'pre_normal':
            self.mab = MAB_PRE_NORMAL(compress_dim, dim, compress_dim, num_heads, ln=ln)
        elif pma_mode == 'pre_gptj':
            self.mab = MAB_PRE_GPTJ(compress_dim, dim, compress_dim, num_heads, ln=ln)
        else:
            raise ValueError(f"Error, the pma_mode {pma_mode} is not implemented !")

    def forward(self, X, pad_mask):
        """Pool ``X`` (batch, n, dim) using the seeds, repeated once per batch item, as queries."""
        # Unless the seeds are bf16, run attention on a float32 copy of X.
        if self.S.dtype != torch.bfloat16:
            X = X.float()
        return self.mab(self.S.repeat(X.size(0), 1, 1), X, pad_mask)


class CodeFuse_CGE_Small(PreTrainedModel):
    """Phi-3 backbone that maps each input text to a single embedding vector.

    Config fields read here: ``embedding_method``, ``inf_seq_length``, ``padding_side``,
    ``keep_max_layer``, ``pma_num_heads``, ``pma_ln``, ``pma_norm``, ``compress_dim`` and
    ``pma_norm_mode``, plus whatever ``Phi3ForCausalLM`` consumes.
    """

    # Pooling strategies accepted by ``get_sentence_embedding``.
    _EMBEDDING_METHODS = ('last', 'mean', 'weighted', 'pma')

    def __init__(self, config):
        super().__init__(config)
        self.plm_model = Phi3ForCausalLM(config)
        self.embedding_method = config.embedding_method
        self.inf_seq_length = config.inf_seq_length
        self.padding_side = config.padding_side
        # Index into the backbone's ``hidden_states`` tuple used for pooling.
        self.keep_max_layer = config.keep_max_layer
        self.emb_dim = self.plm_model.model.embed_tokens.weight.size(1)
        self.num_heads = config.pma_num_heads
        self.ln = config.pma_ln
        self.norm = config.pma_norm
        self.compress_dim = config.compress_dim
        self.pma_mode = config.pma_norm_mode
        self.mha_pma = PMA(self.emb_dim, self.compress_dim, self.num_heads, 1, ln=self.ln, pma_mode=self.pma_mode)

    @staticmethod
    def _last_token_index(mask):
        """Return, for each row of ``mask`` (batch, seq), the index of its last non-zero entry.

        Works for right- and left-padded masks alike. For right padding this equals
        ``mask.sum(-1) - 1``. Rows with no non-zero entry map to ``seq - 1``, which is
        where the old ``-1`` index wrapped to.
        """
        nonzero = (mask != 0).long()
        # argmax returns the first maximal index, so on the reversed mask it gives
        # the distance of the last valid position from the end.
        from_end = nonzero.flip(-1).argmax(-1)
        return mask.size(-1) - 1 - from_end

    def last_embedding(self, A, index):
        """Gather ``A[b, index[b], :]`` for every batch row b. ``A`` is (batch, seq, emb)."""
        bs, seq, emb = A.size()
        res = A[torch.arange(bs, device=A.device), index, :]
        return res

    def mean_embedding(self, A, mask):
        """Average ``A`` (batch, seq, emb) over the positions where ``mask`` (batch, seq) is set.

        Rows whose mask is entirely zero yield a zero vector instead of NaN.
        """
        summed = (A * mask.unsqueeze(-1)).sum(1)
        counts = mask.sum(1).unsqueeze(-1).clamp(min=1)
        return summed / counts

    def weighted_embedding(self, A, mask):
        """Position-weighted mean of ``A`` (batch, seq, emb) over masked-in positions.

        Position i (0-based) gets weight i + 1, so later positions contribute more.
        ``mask`` may be (batch, seq) or (batch, 1, seq). Rows whose mask is entirely
        zero yield a zero vector instead of NaN.
        """
        # Only drop an explicit singleton middle dim. An unconditional squeeze(1)
        # collapsed (batch, 1) masks when seq == 1 and then broke the expand below.
        if mask.dim() == 3:
            mask = mask.squeeze(1)
        weights = torch.arange(1, A.size(1) + 1, device=A.device, dtype=torch.float32)
        weights = weights.view(1, -1, 1).expand(A.size())
        mask_expanded = mask.unsqueeze(-1).expand(A.size()).float().to(A.device)
        sum_embedding = torch.sum(A * mask_expanded * weights, dim=1)
        # Weights are >= 1, so any row with at least one valid token already sums to >= 1.
        sum_mask = torch.sum(mask_expanded * weights, dim=1).clamp(min=1.0)
        return sum_embedding / sum_mask

    def pma_embedding(self, A, mask):
        """Pool ``A`` with the single-seed PMA module and drop the seed dimension."""
        res = self.mha_pma(A, mask).squeeze(1)
        return res

    def get_sentence_embedding(self, embedding_method, **inputs):
        """Run the backbone on tokenized ``inputs`` and pool one vector per row.

        Args:
            embedding_method: one of 'last', 'mean', 'weighted' or 'pma'.
            **inputs: tokenizer output; must contain 'input_ids' and 'attention_mask'.

        Returns:
            One embedding per batch row. It is L2-normalized when ``self.norm`` is falsy.

        Raises:
            ValueError: if ``embedding_method`` is not supported. Previously this path
                referenced an undefined ``logger`` and then an unbound local.
        """
        if embedding_method not in self._EMBEDDING_METHODS:
            raise ValueError('Error, no {} way to obtain embeddings'.format(embedding_method))

        attention_mask = inputs['attention_mask']
        outputs = self.plm_model(inputs['input_ids'], attention_mask, output_hidden_states=True)
        embedding = outputs.hidden_states[self.keep_max_layer]

        if embedding_method == 'last':
            index = self._last_token_index(attention_mask)
            res_embedding = self.last_embedding(embedding, index)
        elif embedding_method == 'mean':
            res_embedding = self.mean_embedding(embedding, attention_mask)
        elif embedding_method == 'weighted':
            res_embedding = self.weighted_embedding(embedding, attention_mask)
        else:  # 'pma'
            res_embedding = self.pma_embedding(embedding, attention_mask)

        # NOTE(review): normalization runs when ``pma_norm`` is falsy, which reads
        # inverted relative to the field name. Preserved as-is; confirm intent.
        if not self.norm:
            res_embedding = F.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12)
        return res_embedding

    def encode(self, tokenizer, sentences, batch_size=32, convert_to_numpy=True, convert_to_tensor=False, show_progress_bar=True, max_seq_length=None, **kwargs):
        """Embed one sentence or a sequence of sentences.

        Sentences are processed in batches of ``batch_size``, ordered by descending
        character length. Results are returned in the original input order.

        Args:
            tokenizer: callable tokenizer producing 'input_ids'/'attention_mask' tensors.
            sentences: a single string or a sized sequence of strings.
            batch_size: sentences per forward pass.
            convert_to_numpy: move embeddings to CPU (bf16 upcast to float32) and
                return a numpy array, unless ``convert_to_tensor`` is set.
            convert_to_tensor: return a single stacked tensor (takes precedence).
            show_progress_bar: show a tqdm progress bar over batches.
            max_seq_length: truncation length. Defaults to ``self.inf_seq_length``.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            A stacked tensor, a numpy array, or a list of per-sentence tensors,
            depending on the flags above. A single-string input yields one embedding.
        """
        if max_seq_length is None:
            max_seq_length = self.inf_seq_length

        input_is_string = False
        if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
            sentences = [sentences]
            input_is_string = True

        all_embeddings = []
        length_sorted_idx = np.argsort([-len(s) for s in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        with torch.no_grad():
            for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
                sentences_batch = sentences_sorted[start_index: start_index + batch_size]
                inputs = tokenizer(sentences_batch, padding=True, truncation=True, max_length=max_seq_length, add_special_tokens=False, return_tensors='pt').to(self.plm_model.device)
                embeddings = self.get_sentence_embedding(self.embedding_method, **inputs).detach()
                if convert_to_numpy:
                    # numpy has no bfloat16 dtype, so upcast before leaving torch.
                    if embeddings.dtype == torch.bfloat16:
                        embeddings = embeddings.cpu().to(torch.float32)
                    else:
                        embeddings = embeddings.cpu()
                all_embeddings.extend(embeddings)

        # Undo the length sort so outputs line up with the input order.
        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        if input_is_string:
            all_embeddings = all_embeddings[0]
        return all_embeddings