from enum import Enum

import torch
import torch.nn.functional as F
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    AutoModelForQuestionAnswering,
    AutoModelForMultipleChoice
)

from .token_classification import (
    BertPrefixForTokenClassification,
    RobertaPrefixForTokenClassification,
    DebertaPrefixForTokenClassification,
    DebertaV2PrefixForTokenClassification
)
from .sequence_classification import (
    BertPrefixForSequenceClassification,
    BertPromptForSequenceClassification,
    RobertaPrefixForSequenceClassification,
    RobertaPromptForSequenceClassification,
    DebertaPrefixForSequenceClassification,
    GPT2PrefixForSequenceClassification,
    GPT2PromptForSequenceClassification
)
from .question_answering import (
    BertPrefixForQuestionAnswering,
    RobertaPrefixModelForQuestionAnswering,
    DebertaPrefixModelForQuestionAnswering
)
from .multiple_choice import (
    BertPrefixForMultipleChoice,
    RobertaPrefixForMultipleChoice,
    DebertaPrefixForMultipleChoice,
    BertPromptForMultipleChoice,
    RobertaPromptForMultipleChoice
)
from .sequence_causallm import (
    BertPromptForMaskedLM,
    BertPrefixForMaskedLM,
    RobertaPromptForMaskedLM,
    RobertaPrefixForMaskedLM,
    LlamaPromptForMaskedLM,
    LlamaPrefixForMaskedLM,
    OPTPrefixForMaskedLM,
    OPTPromptForMaskedLM
)


def get_loss(predict_logits, labels_ids):
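    """Negative log-marginal-likelihood of a set of candidate label ids.

    predict_logits has shape (..., vocab_size) and labels_ids has shape
    (..., num_label_tokens); id 0 is treated as padding and masked out
    before the logsumexp over candidates.
    """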
    labels_ids = labels_ids.to(predict_logits.device)
    predict_logp = F.log_softmax(predict_logits, dim=-1)
    target_logp = predict_logp.gather(-1, labels_ids)
    target_logp = target_logp - 1e32 * labels_ids.eq(0)  # mask out padding id 0
    target_logp = torch.logsumexp(target_logp, dim=-1)
    return -target_logp


def use_grad(base_model, enable):
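    """Toggle requires_grad on all parameters and switch train/eval mode."""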
    if enable:
        for param in base_model.parameters():
            param.requires_grad = True
        base_model.train()
    else:
        for param in base_model.parameters():
            param.requires_grad = False
        base_model.eval()


def get_embeddings(model, config):
"""Returns the wordpiece embedding module."""
base_model = getattr(model, config.model_type)
embeddings = base_model.embeddings.word_embeddings
return embeddings
class GradientStorage:
"""
This object stores the intermediate gradients of the output a the given PyTorch module, which
otherwise might not be retained.
"""
def __init__(self, module):
self._stored_gradient = None
module.register_backward_hook(self.hook)
def hook(self, module, grad_in, grad_out):
assert grad_out is not None
self._stored_gradient = grad_out[0]
def reset(self):
self._stored_gradient = None
def get(self):
return self._stored_gradient
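
# Hypothetical usage sketch (not part of this module's API): pair
# GradientStorage with get_embeddings to capture token-embedding gradients.
#   embeddings = get_embeddings(model, config)
#   embedding_gradient = GradientStorage(embeddings)
#   loss.backward()
#   grad = embedding_gradient.get()  # (batch_size, seq_len, hidden_size)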


class TaskType(Enum):
    TOKEN_CLASSIFICATION = 1
    SEQUENCE_CLASSIFICATION = 2
    QUESTION_ANSWERING = 3
    MULTIPLE_CHOICE = 4


PREFIX_MODELS = {
    "bert": {
        TaskType.TOKEN_CLASSIFICATION: BertPrefixForTokenClassification,
        TaskType.SEQUENCE_CLASSIFICATION: BertPrefixForMaskedLM,  # BertPrefixForSequenceClassification
        TaskType.QUESTION_ANSWERING: BertPrefixForQuestionAnswering,
        TaskType.MULTIPLE_CHOICE: BertPrefixForMultipleChoice
    },
    "roberta": {
        TaskType.TOKEN_CLASSIFICATION: RobertaPrefixForTokenClassification,
        TaskType.SEQUENCE_CLASSIFICATION: RobertaPrefixForMaskedLM,  # RobertaPrefixForSequenceClassification
        TaskType.QUESTION_ANSWERING: RobertaPrefixModelForQuestionAnswering,
        TaskType.MULTIPLE_CHOICE: RobertaPrefixForMultipleChoice,
    },
    "deberta": {
        TaskType.TOKEN_CLASSIFICATION: DebertaPrefixForTokenClassification,
        TaskType.SEQUENCE_CLASSIFICATION: DebertaPrefixForSequenceClassification,
        TaskType.QUESTION_ANSWERING: DebertaPrefixModelForQuestionAnswering,
        TaskType.MULTIPLE_CHOICE: DebertaPrefixForMultipleChoice,
    },
    "deberta-v2": {
        TaskType.TOKEN_CLASSIFICATION: DebertaV2PrefixForTokenClassification,
        TaskType.SEQUENCE_CLASSIFICATION: None,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    },
    "gpt2": {
        TaskType.TOKEN_CLASSIFICATION: None,
        TaskType.SEQUENCE_CLASSIFICATION: GPT2PrefixForSequenceClassification,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    },
    "llama": {
        TaskType.TOKEN_CLASSIFICATION: None,
        TaskType.SEQUENCE_CLASSIFICATION: LlamaPrefixForMaskedLM,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    },
    "opt": {
        TaskType.TOKEN_CLASSIFICATION: None,
        TaskType.SEQUENCE_CLASSIFICATION: OPTPrefixForMaskedLM,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    }
}

PROMPT_MODELS = {
    "bert": {
        TaskType.SEQUENCE_CLASSIFICATION: BertPromptForMaskedLM,  # BertPromptForSequenceClassification
        TaskType.MULTIPLE_CHOICE: BertPromptForMultipleChoice
    },
    "roberta": {
        TaskType.SEQUENCE_CLASSIFICATION: RobertaPromptForMaskedLM,  # RobertaPromptForSequenceClassification
        TaskType.MULTIPLE_CHOICE: RobertaPromptForMultipleChoice
    },
    "gpt2": {
        TaskType.SEQUENCE_CLASSIFICATION: GPT2PromptForSequenceClassification,
        TaskType.MULTIPLE_CHOICE: None
    },
    "llama": {
        TaskType.TOKEN_CLASSIFICATION: None,
        TaskType.SEQUENCE_CLASSIFICATION: LlamaPromptForMaskedLM,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    },
    "opt": {
        TaskType.TOKEN_CLASSIFICATION: None,
        TaskType.SEQUENCE_CLASSIFICATION: OPTPromptForMaskedLM,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    }
}

AUTO_MODELS = {
    TaskType.TOKEN_CLASSIFICATION: AutoModelForTokenClassification,
    TaskType.SEQUENCE_CLASSIFICATION: AutoModelForSequenceClassification,
    TaskType.QUESTION_ANSWERING: AutoModelForQuestionAnswering,
    TaskType.MULTIPLE_CHOICE: AutoModelForMultipleChoice,
}
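
# The three tables above map (backbone, task) to a concrete model class:
# PREFIX_MODELS for P-Tuning v2 prefix-tuning, PROMPT_MODELS for prompt
# tuning, and AUTO_MODELS for full fine-tuning. For example,
# PREFIX_MODELS["bert"][TaskType.QUESTION_ANSWERING] resolves to
# BertPrefixForQuestionAnswering; None marks unsupported combinations.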


def get_model(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False, tokenizer=None):
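    """Load a prefix-tuning, prompt-tuning, or fully fine-tuned model for task_type.

    model_args.prefix selects a class from PREFIX_MODELS, model_args.prompt one
    from PROMPT_MODELS; otherwise the matching Auto* class is used. OPT and
    LLaMA checkpoint names are expanded to their facebook/ and openlm-research/
    hub paths. When fix_bert is set, the backbone parameters are frozen.
    """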
    model_name_or_path = (
        f'openlm-research/{model_args.model_name_or_path}'
        if "llama" in model_args.model_name_or_path
        else model_args.model_name_or_path
    )
    if model_args.prefix:
        config.hidden_dropout_prob = model_args.hidden_dropout_prob
        config.pre_seq_len = model_args.pre_seq_len
        config.prefix_projection = model_args.prefix_projection
        config.prefix_hidden_size = model_args.prefix_hidden_size
        model_class = PREFIX_MODELS[config.model_type][task_type]
        if "opt" in model_args.model_name_or_path:
            model_name_or_path = f'facebook/{model_args.model_name_or_path}'
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                revision=model_args.model_revision,
                trust_remote_code=True
            )
        elif "llama" in model_args.model_name_or_path:
            model_name_or_path = f'openlm-research/{model_args.model_name_or_path}'
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                device_map='auto',
            )
        else:
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                trust_remote_code=True,
                revision=model_args.model_revision
            )
    elif model_args.prompt:
        config.pre_seq_len = model_args.pre_seq_len
        model_class = PROMPT_MODELS[config.model_type][task_type]
        if "opt" in model_args.model_name_or_path:
            # Prompt-tuning currently pins OPT to the 1.3B checkpoint.
            model_name_or_path = 'facebook/opt-1.3b'
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                revision=model_args.model_revision,
                trust_remote_code=True
            )
        elif "llama" in model_args.model_name_or_path:
            model_name_or_path = f'openlm-research/{model_args.model_name_or_path}'
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                device_map='auto',
            )
        else:
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                revision=model_args.model_revision,
                trust_remote_code=True
            )
    else:
        model_class = AUTO_MODELS[task_type]
        model = model_class.from_pretrained(
            model_name_or_path,
            config=config,
            revision=model_args.model_revision,
        )
    base_param = 0
    if fix_bert:
        if config.model_type == "bert":
            for param in model.bert.parameters():
                param.requires_grad = False
            for _, param in model.bert.named_parameters():
                base_param += param.numel()
        elif config.model_type == "roberta":
            for param in model.roberta.parameters():
                param.requires_grad = False
            for _, param in model.roberta.named_parameters():
                base_param += param.numel()
        elif config.model_type == "deberta":
            for param in model.deberta.parameters():
                param.requires_grad = False
            for _, param in model.deberta.named_parameters():
                base_param += param.numel()
        elif config.model_type == "gpt2":
            for param in model.gpt2.parameters():
                param.requires_grad = False
            for _, param in model.gpt2.named_parameters():
                base_param += param.numel()

    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
    total_param = all_param - base_param
    print('***** Total param: {:0.3f}M, P-Tuning-V2 param: {:0.3f}M *****'.format(
        all_param / 1e6, total_param / 1e6))
    return model
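
# Minimal usage sketch (hypothetical; `model_args` is whatever argument object
# the training script builds, exposing the attributes read above):
#   config = AutoConfig.from_pretrained(model_args.model_name_or_path)
#   model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)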


def get_model_deprecated(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False):
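    """Deprecated predecessor of get_model, kept for reference."""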
    if model_args.prefix:
        config.hidden_dropout_prob = model_args.hidden_dropout_prob
        config.pre_seq_len = model_args.pre_seq_len
        config.prefix_projection = model_args.prefix_projection
        config.prefix_hidden_size = model_args.prefix_hidden_size

        if task_type == TaskType.TOKEN_CLASSIFICATION:
            from model.token_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
        elif task_type == TaskType.SEQUENCE_CLASSIFICATION:
            from model.sequence_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
        elif task_type == TaskType.QUESTION_ANSWERING:
            from model.question_answering import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
        elif task_type == TaskType.MULTIPLE_CHOICE:
            from model.multiple_choice import BertPrefixModel

        if config.model_type == "bert":
            model = BertPrefixModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif config.model_type == "roberta":
            model = RobertaPrefixModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif config.model_type == "deberta":
            model = DebertaPrefixModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif config.model_type == "deberta-v2":
            model = DebertaV2PrefixModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        else:
            raise NotImplementedError
    elif model_args.prompt:
        config.pre_seq_len = model_args.pre_seq_len
        from model.sequence_classification import BertPromptModel, RobertaPromptModel
        if config.model_type == "bert":
            model = BertPromptModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif config.model_type == "roberta":
            model = RobertaPromptModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        else:
            raise NotImplementedError
    else:
        if task_type == TaskType.TOKEN_CLASSIFICATION:
            model = AutoModelForTokenClassification.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif task_type == TaskType.SEQUENCE_CLASSIFICATION:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif task_type == TaskType.QUESTION_ANSWERING:
            model = AutoModelForQuestionAnswering.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif task_type == TaskType.MULTIPLE_CHOICE:
            model = AutoModelForMultipleChoice.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )

    bert_param = 0
    if fix_bert:
        if config.model_type == "bert":
            for param in model.bert.parameters():
                param.requires_grad = False
            for _, param in model.bert.named_parameters():
                bert_param += param.numel()
        elif config.model_type == "roberta":
            for param in model.roberta.parameters():
                param.requires_grad = False
            for _, param in model.roberta.named_parameters():
                bert_param += param.numel()
        elif config.model_type == "deberta":
            for param in model.deberta.parameters():
                param.requires_grad = False
            for _, param in model.deberta.named_parameters():
                bert_param += param.numel()

    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
    total_param = all_param - bert_param
    print('***** total param is {} *****'.format(total_param))
    return model