from enum import Enum import torch from .token_classification import ( BertPrefixForTokenClassification, RobertaPrefixForTokenClassification, DebertaPrefixForTokenClassification, DebertaV2PrefixForTokenClassification ) from .sequence_classification import ( BertPrefixForSequenceClassification, BertPromptForSequenceClassification, RobertaPrefixForSequenceClassification, RobertaPromptForSequenceClassification, DebertaPrefixForSequenceClassification, GPT2PrefixForSequenceClassification, GPT2PromptForSequenceClassification ) from .question_answering import ( BertPrefixForQuestionAnswering, RobertaPrefixModelForQuestionAnswering, DebertaPrefixModelForQuestionAnswering ) from .multiple_choice import ( BertPrefixForMultipleChoice, RobertaPrefixForMultipleChoice, DebertaPrefixForMultipleChoice, BertPromptForMultipleChoice, RobertaPromptForMultipleChoice ) from .sequence_causallm import ( BertPromptForMaskedLM, BertPrefixForMaskedLM, RobertaPromptForMaskedLM, RobertaPrefixForMaskedLM, LlamaPromptForMaskedLM, LlamaPrefixForMaskedLM, OPTPrefixForMaskedLM, OPTPromptForMaskedLM ) from transformers import ( AutoConfig, AutoModelForTokenClassification, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelForMultipleChoice ) import torch.nn.functional as F def get_loss(predict_logits, labels_ids): labels_ids = labels_ids.to(predict_logits.device) predict_logp = F.log_softmax(predict_logits, dim=-1) target_logp = predict_logp.gather(-1, labels_ids) target_logp = target_logp - 1e32 * labels_ids.eq(0) # Apply mask target_logp = torch.logsumexp(target_logp, dim=-1) return -target_logp def use_grad(base_model, use_grad): if use_grad: for param in base_model.parameters(): param.requires_grad = True base_model.train() else: for param in base_model.parameters(): param.requires_grad = False base_model.eval() def get_embeddings(model, config): """Returns the wordpiece embedding module.""" base_model = getattr(model, config.model_type) embeddings = base_model.embeddings.word_embeddings return embeddings class GradientStorage: """ This object stores the intermediate gradients of the output a the given PyTorch module, which otherwise might not be retained. """ def __init__(self, module): self._stored_gradient = None module.register_backward_hook(self.hook) def hook(self, module, grad_in, grad_out): assert grad_out is not None self._stored_gradient = grad_out[0] def reset(self): self._stored_gradient = None def get(self): return self._stored_gradient class TaskType(Enum): TOKEN_CLASSIFICATION = 1, SEQUENCE_CLASSIFICATION = 2, QUESTION_ANSWERING = 3, MULTIPLE_CHOICE = 4 PREFIX_MODELS = { "bert": { TaskType.TOKEN_CLASSIFICATION: BertPrefixForTokenClassification, TaskType.SEQUENCE_CLASSIFICATION: BertPrefixForMaskedLM, #BertPrefixForSequenceClassification, TaskType.QUESTION_ANSWERING: BertPrefixForQuestionAnswering, TaskType.MULTIPLE_CHOICE: BertPrefixForMultipleChoice }, "roberta": { TaskType.TOKEN_CLASSIFICATION: RobertaPrefixForTokenClassification, TaskType.SEQUENCE_CLASSIFICATION: RobertaPrefixForMaskedLM, #RobertaPrefixForSequenceClassification, TaskType.QUESTION_ANSWERING: RobertaPrefixModelForQuestionAnswering, TaskType.MULTIPLE_CHOICE: RobertaPrefixForMultipleChoice, }, "deberta": { TaskType.TOKEN_CLASSIFICATION: DebertaPrefixForTokenClassification, TaskType.SEQUENCE_CLASSIFICATION: DebertaPrefixForSequenceClassification, TaskType.QUESTION_ANSWERING: DebertaPrefixModelForQuestionAnswering, TaskType.MULTIPLE_CHOICE: DebertaPrefixForMultipleChoice, }, "deberta-v2": { TaskType.TOKEN_CLASSIFICATION: DebertaV2PrefixForTokenClassification, TaskType.SEQUENCE_CLASSIFICATION: None, TaskType.QUESTION_ANSWERING: None, TaskType.MULTIPLE_CHOICE: None, }, "gpt2": { TaskType.TOKEN_CLASSIFICATION: None, TaskType.SEQUENCE_CLASSIFICATION: GPT2PrefixForSequenceClassification, TaskType.QUESTION_ANSWERING: None, TaskType.MULTIPLE_CHOICE: None, }, "llama": { TaskType.TOKEN_CLASSIFICATION: None, TaskType.SEQUENCE_CLASSIFICATION: LlamaPrefixForMaskedLM, TaskType.QUESTION_ANSWERING: None, TaskType.MULTIPLE_CHOICE: None, }, "opt": { TaskType.TOKEN_CLASSIFICATION: None, TaskType.SEQUENCE_CLASSIFICATION: OPTPrefixForMaskedLM, TaskType.QUESTION_ANSWERING: None, TaskType.MULTIPLE_CHOICE: None, } } PROMPT_MODELS = { "bert": { TaskType.SEQUENCE_CLASSIFICATION: BertPromptForMaskedLM, #BertPromptForSequenceClassification, TaskType.MULTIPLE_CHOICE: BertPromptForMultipleChoice }, "roberta": { TaskType.SEQUENCE_CLASSIFICATION: RobertaPromptForMaskedLM, #RobertaPromptForSequenceClassification, TaskType.MULTIPLE_CHOICE: RobertaPromptForMultipleChoice }, "gpt2": { TaskType.SEQUENCE_CLASSIFICATION: GPT2PromptForSequenceClassification, TaskType.MULTIPLE_CHOICE: None }, "llama": { TaskType.TOKEN_CLASSIFICATION: None, TaskType.SEQUENCE_CLASSIFICATION: LlamaPromptForMaskedLM, TaskType.QUESTION_ANSWERING: None, TaskType.MULTIPLE_CHOICE: None, }, "opt": { TaskType.TOKEN_CLASSIFICATION: None, TaskType.SEQUENCE_CLASSIFICATION: OPTPromptForMaskedLM, TaskType.QUESTION_ANSWERING: None, TaskType.MULTIPLE_CHOICE: None, } } AUTO_MODELS = { TaskType.TOKEN_CLASSIFICATION: AutoModelForTokenClassification, TaskType.SEQUENCE_CLASSIFICATION: AutoModelForSequenceClassification, TaskType.QUESTION_ANSWERING: AutoModelForQuestionAnswering, TaskType.MULTIPLE_CHOICE: AutoModelForMultipleChoice, } def get_model(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False, tokenizer=None): model_name_or_path = f'openlm-research/{model_args.model_name_or_path}' if "llama" in model_args.model_name_or_path else model_args.model_name_or_path if model_args.prefix: config.hidden_dropout_prob = model_args.hidden_dropout_prob config.pre_seq_len = model_args.pre_seq_len config.prefix_projection = model_args.prefix_projection config.prefix_hidden_size = model_args.prefix_hidden_size model_class = PREFIX_MODELS[config.model_type][task_type] if "opt" in model_args.model_name_or_path: model_name_or_path = f'facebook/{model_args.model_name_or_path}' model = model_class.from_pretrained( model_name_or_path, config=config, revision=model_args.model_revision, trust_remote_code=True ) elif "llama" in model_args.model_name_or_path: model_name_or_path = f'openlm-research/{model_args.model_name_or_path}' model = model_class.from_pretrained( model_name_or_path, config=config, trust_remote_code=True, torch_dtype=torch.float32, device_map='auto', ) else: model = model_class.from_pretrained( model_name_or_path, config=config, trust_remote_code=True, revision=model_args.model_revision ) elif model_args.prompt: config.pre_seq_len = model_args.pre_seq_len model_class = PROMPT_MODELS[config.model_type][task_type] if "opt" in model_args.model_name_or_path: model_name_or_path = f'facebook/opt-1.3b' model = model_class.from_pretrained( model_name_or_path, config=config, revision=model_args.model_revision, trust_remote_code=True ) elif "llama" in model_args.model_name_or_path: model_name_or_path = f'openlm-research/{model_args.model_name_or_path}' model = model_class.from_pretrained( model_name_or_path, config=config, trust_remote_code=True, torch_dtype=torch.float32, device_map='auto', ) else: model = model_class.from_pretrained( model_name_or_path, config=config, revision=model_args.model_revision, trust_remote_code=True ) else: model_class = AUTO_MODELS[task_type] model = model_class.from_pretrained( model_name_or_path, config=config, revision=model_args.model_revision, ) base_param = 0 if fix_bert: if config.model_type == "bert": for param in model.bert.parameters(): param.requires_grad = False for _, param in model.bert.named_parameters(): base_param += param.numel() elif config.model_type == "roberta": for param in model.roberta.parameters(): param.requires_grad = False for _, param in model.roberta.named_parameters(): base_param += param.numel() elif config.model_type == "deberta": for param in model.deberta.parameters(): param.requires_grad = False for _, param in model.deberta.named_parameters(): base_param += param.numel() elif config.model_type == "gpt2": for param in model.gpt2.parameters(): param.requires_grad = False for _, param in model.gpt2.named_parameters(): base_param += param.numel() all_param = 0 for _, param in model.named_parameters(): all_param += param.numel() total_param = all_param - base_param print('***** Backborn param:{:0.3f}M, P-Tuning-V2 param is {} *****'.format(all_param, total_param)) return model def get_model_deprecated(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False): if model_args.prefix: config.hidden_dropout_prob = model_args.hidden_dropout_prob config.pre_seq_len = model_args.pre_seq_len config.prefix_projection = model_args.prefix_projection config.prefix_hidden_size = model_args.prefix_hidden_size if task_type == TaskType.TOKEN_CLASSIFICATION: from model.token_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel elif task_type == TaskType.SEQUENCE_CLASSIFICATION: from model.sequence_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel elif task_type == TaskType.QUESTION_ANSWERING: from model.question_answering import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel elif task_type == TaskType.MULTIPLE_CHOICE: from model.multiple_choice import BertPrefixModel if config.model_type == "bert": model = BertPrefixModel.from_pretrained( model_args.model_name_or_path, config=config, revision=model_args.model_revision, ) elif config.model_type == "roberta": model = RobertaPrefixModel.from_pretrained( model_args.model_name_or_path, config=config, revision=model_args.model_revision, ) elif config.model_type == "deberta": model = DebertaPrefixModel.from_pretrained( model_args.model_name_or_path, config=config, revision=model_args.model_revision, ) elif config.model_type == "deberta-v2": model = DebertaV2PrefixModel.from_pretrained( model_args.model_name_or_path, config=config, revision=model_args.model_revision, ) else: raise NotImplementedError elif model_args.prompt: config.pre_seq_len = model_args.pre_seq_len from model.sequence_classification import BertPromptModel, RobertaPromptModel if config.model_type == "bert": model = BertPromptModel.from_pretrained( model_args.model_name_or_path, config=config, revision=model_args.model_revision, ) elif config.model_type == "roberta": model = RobertaPromptModel.from_pretrained( model_args.model_name_or_path, config=config, revision=model_args.model_revision, ) else: raise NotImplementedError else: if task_type == TaskType.TOKEN_CLASSIFICATION: model = AutoModelForTokenClassification.from_pretrained( model_args.model_name_or_path, config=config, revision=model_args.model_revision, ) elif task_type == TaskType.SEQUENCE_CLASSIFICATION: model = AutoModelForSequenceClassification.from_pretrained( model_args.model_name_or_path, config=config, revision=model_args.model_revision, ) elif task_type == TaskType.QUESTION_ANSWERING: model = AutoModelForQuestionAnswering.from_pretrained( model_args.model_name_or_path, config=config, revision=model_args.model_revision, ) elif task_type == TaskType.MULTIPLE_CHOICE: model = AutoModelForMultipleChoice.from_pretrained( model_args.model_name_or_path, config=config, revision=model_args.model_revision, ) bert_param = 0 if fix_bert: if config.model_type == "bert": for param in model.bert.parameters(): param.requires_grad = False for _, param in model.bert.named_parameters(): bert_param += param.numel() elif config.model_type == "roberta": for param in model.roberta.parameters(): param.requires_grad = False for _, param in model.roberta.named_parameters(): bert_param += param.numel() elif config.model_type == "deberta": for param in model.deberta.parameters(): param.requires_grad = False for _, param in model.deberta.named_parameters(): bert_param += param.numel() all_param = 0 for _, param in model.named_parameters(): all_param += param.numel() total_param = all_param - bert_param print('***** total param is {} *****'.format(total_param)) return model