from enum import Enum

import torch
import torch.nn.functional as F
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    AutoModelForQuestionAnswering,
    AutoModelForMultipleChoice
)

from .token_classification import (
    BertPrefixForTokenClassification,
    RobertaPrefixForTokenClassification,
    DebertaPrefixForTokenClassification,
    DebertaV2PrefixForTokenClassification
)
from .sequence_classification import (
    BertPrefixForSequenceClassification,
    BertPromptForSequenceClassification,
    RobertaPrefixForSequenceClassification,
    RobertaPromptForSequenceClassification,
    DebertaPrefixForSequenceClassification,
    GPT2PrefixForSequenceClassification,
    GPT2PromptForSequenceClassification
)
from .question_answering import (
    BertPrefixForQuestionAnswering,
    RobertaPrefixModelForQuestionAnswering,
    DebertaPrefixModelForQuestionAnswering
)
from .multiple_choice import (
    BertPrefixForMultipleChoice,
    RobertaPrefixForMultipleChoice,
    DebertaPrefixForMultipleChoice,
    BertPromptForMultipleChoice,
    RobertaPromptForMultipleChoice
)
from .sequence_causallm import (
    BertPromptForMaskedLM,
    BertPrefixForMaskedLM,
    RobertaPromptForMaskedLM,
    RobertaPrefixForMaskedLM,
    LlamaPromptForMaskedLM,
    LlamaPrefixForMaskedLM,
    OPTPrefixForMaskedLM,
    OPTPromptForMaskedLM
)


def get_loss(predict_logits, labels_ids):
    """Negative log-likelihood of the gold label tokens under the predicted logits."""
    labels_ids = labels_ids.to(predict_logits.device)
    predict_logp = F.log_softmax(predict_logits, dim=-1)
    target_logp = predict_logp.gather(-1, labels_ids)
    # Mask padding slots (token id 0) by pushing their log-probs towards -inf
    # so they are ignored by the logsumexp below.
    target_logp = target_logp - 1e32 * labels_ids.eq(0)
    # Aggregate over the (possibly multiple) label tokens of each example.
    target_logp = torch.logsumexp(target_logp, dim=-1)
    return -target_logp
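
# Illustrative sketch (not part of the original file): how get_loss is
# typically called for a verbalizer-style masked-LM objective. `mlm_logits`,
# `mask_position` and `label_token_ids` are hypothetical names;
# `label_token_ids` is assumed to be padded with 0 where a label has fewer
# tokens, matching the eq(0) mask above.
#
#   mlm_logits = model(**batch).logits[:, mask_position, :]   # (batch, vocab)
#   loss = get_loss(mlm_logits, label_token_ids).mean()       # scalar loss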


def use_grad(base_model, use_grad):
    """Enable or disable gradients (and train/eval mode) for the whole model."""
    for param in base_model.parameters():
        param.requires_grad = use_grad
    if use_grad:
        base_model.train()
    else:
        base_model.eval()


def get_embeddings(model, config):
    """Returns the wordpiece embedding module of the underlying base model."""
    base_model = getattr(model, config.model_type)
    embeddings = base_model.embeddings.word_embeddings
    return embeddings


class GradientStorage:
    """
    Stores the intermediate gradient of the output of a given PyTorch module,
    which otherwise might not be retained.
    """
    def __init__(self, module):
        self._stored_gradient = None
        # register_full_backward_hook supersedes the deprecated register_backward_hook.
        module.register_full_backward_hook(self.hook)

    def hook(self, module, grad_in, grad_out):
        assert grad_out is not None
        self._stored_gradient = grad_out[0]

    def reset(self):
        self._stored_gradient = None

    def get(self):
        return self._stored_gradient
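
# Illustrative sketch (assumed usage, not from the original file): pairing
# GradientStorage with get_embeddings to read the gradient w.r.t. the input
# embeddings after a backward pass, e.g. for gradient-guided prompt search.
#
#   embeddings = get_embeddings(model, config)
#   embedding_grads = GradientStorage(embeddings)
#   loss = get_loss(model(**batch).logits, label_token_ids).mean()
#   loss.backward()
#   grad = embedding_grads.get()   # gradient w.r.t. the embedding output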


class TaskType(Enum):
    # No trailing commas here: "TOKEN_CLASSIFICATION = 1," would make the
    # member's value the tuple (1,) instead of the int 1.
    TOKEN_CLASSIFICATION = 1
    SEQUENCE_CLASSIFICATION = 2
    QUESTION_ANSWERING = 3
    MULTIPLE_CHOICE = 4


PREFIX_MODELS = {
    "bert": {
        TaskType.TOKEN_CLASSIFICATION: BertPrefixForTokenClassification,
        TaskType.SEQUENCE_CLASSIFICATION: BertPrefixForMaskedLM,  # BertPrefixForSequenceClassification
        TaskType.QUESTION_ANSWERING: BertPrefixForQuestionAnswering,
        TaskType.MULTIPLE_CHOICE: BertPrefixForMultipleChoice
    },
    "roberta": {
        TaskType.TOKEN_CLASSIFICATION: RobertaPrefixForTokenClassification,
        TaskType.SEQUENCE_CLASSIFICATION: RobertaPrefixForMaskedLM,  # RobertaPrefixForSequenceClassification
        TaskType.QUESTION_ANSWERING: RobertaPrefixModelForQuestionAnswering,
        TaskType.MULTIPLE_CHOICE: RobertaPrefixForMultipleChoice,
    },
    "deberta": {
        TaskType.TOKEN_CLASSIFICATION: DebertaPrefixForTokenClassification,
        TaskType.SEQUENCE_CLASSIFICATION: DebertaPrefixForSequenceClassification,
        TaskType.QUESTION_ANSWERING: DebertaPrefixModelForQuestionAnswering,
        TaskType.MULTIPLE_CHOICE: DebertaPrefixForMultipleChoice,
    },
    "deberta-v2": {
        TaskType.TOKEN_CLASSIFICATION: DebertaV2PrefixForTokenClassification,
        TaskType.SEQUENCE_CLASSIFICATION: None,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    },
    "gpt2": {
        TaskType.TOKEN_CLASSIFICATION: None,
        TaskType.SEQUENCE_CLASSIFICATION: GPT2PrefixForSequenceClassification,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    },
    "llama": {
        TaskType.TOKEN_CLASSIFICATION: None,
        TaskType.SEQUENCE_CLASSIFICATION: LlamaPrefixForMaskedLM,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    },
    "opt": {
        TaskType.TOKEN_CLASSIFICATION: None,
        TaskType.SEQUENCE_CLASSIFICATION: OPTPrefixForMaskedLM,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    }
}

PROMPT_MODELS = {
    "bert": {
        TaskType.SEQUENCE_CLASSIFICATION: BertPromptForMaskedLM,  # BertPromptForSequenceClassification
        TaskType.MULTIPLE_CHOICE: BertPromptForMultipleChoice
    },
    "roberta": {
        TaskType.SEQUENCE_CLASSIFICATION: RobertaPromptForMaskedLM,  # RobertaPromptForSequenceClassification
        TaskType.MULTIPLE_CHOICE: RobertaPromptForMultipleChoice
    },
    "gpt2": {
        TaskType.SEQUENCE_CLASSIFICATION: GPT2PromptForSequenceClassification,
        TaskType.MULTIPLE_CHOICE: None
    },
    "llama": {
        TaskType.TOKEN_CLASSIFICATION: None,
        TaskType.SEQUENCE_CLASSIFICATION: LlamaPromptForMaskedLM,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    },
    "opt": {
        TaskType.TOKEN_CLASSIFICATION: None,
        TaskType.SEQUENCE_CLASSIFICATION: OPTPromptForMaskedLM,
        TaskType.QUESTION_ANSWERING: None,
        TaskType.MULTIPLE_CHOICE: None,
    }
}

AUTO_MODELS = {
    TaskType.TOKEN_CLASSIFICATION: AutoModelForTokenClassification,
    TaskType.SEQUENCE_CLASSIFICATION: AutoModelForSequenceClassification,
    TaskType.QUESTION_ANSWERING: AutoModelForQuestionAnswering,
    TaskType.MULTIPLE_CHOICE: AutoModelForMultipleChoice,
}
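
# Illustrative sketch: how the lookup tables above resolve a concrete model
# class. A None entry means that model/task combination is unsupported.
#
#   model_class = PREFIX_MODELS[config.model_type][TaskType.SEQUENCE_CLASSIFICATION]
#   # e.g. config.model_type == "roberta" -> RobertaPrefixForMaskedLM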


def get_model(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False, tokenizer=None):
    # OpenLLaMA checkpoints are hosted under the openlm-research organization.
    model_name_or_path = f'openlm-research/{model_args.model_name_or_path}' if "llama" in model_args.model_name_or_path else model_args.model_name_or_path
    if model_args.prefix:
        # Prefix-tuning (P-Tuning v2): attach the prefix hyperparameters to the config.
        config.hidden_dropout_prob = model_args.hidden_dropout_prob
        config.pre_seq_len = model_args.pre_seq_len
        config.prefix_projection = model_args.prefix_projection
        config.prefix_hidden_size = model_args.prefix_hidden_size
        model_class = PREFIX_MODELS[config.model_type][task_type]
        if "opt" in model_args.model_name_or_path:
            # OPT checkpoints are hosted under the facebook organization.
            model_name_or_path = f'facebook/{model_args.model_name_or_path}'
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                revision=model_args.model_revision,
                trust_remote_code=True
            )
        elif "llama" in model_args.model_name_or_path:
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                device_map='auto',
            )
        else:
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                trust_remote_code=True,
                revision=model_args.model_revision
            )
    elif model_args.prompt:
        # Prompt-tuning: only the soft prompt length is needed.
        config.pre_seq_len = model_args.pre_seq_len
        model_class = PROMPT_MODELS[config.model_type][task_type]
        if "opt" in model_args.model_name_or_path:
            model_name_or_path = 'facebook/opt-1.3b'
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                revision=model_args.model_revision,
                trust_remote_code=True
            )
        elif "llama" in model_args.model_name_or_path:
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                device_map='auto',
            )
        else:
            model = model_class.from_pretrained(
                model_name_or_path,
                config=config,
                revision=model_args.model_revision,
                trust_remote_code=True
            )
    else:
        # Full fine-tuning baseline via the plain Auto* classes.
        model_class = AUTO_MODELS[task_type]
        model = model_class.from_pretrained(
            model_name_or_path,
            config=config,
            revision=model_args.model_revision,
        )
    base_param = 0
    if fix_bert:
        # Freeze the backbone and count its parameters in a single pass.
        if config.model_type == "bert":
            for param in model.bert.parameters():
                param.requires_grad = False
                base_param += param.numel()
        elif config.model_type == "roberta":
            for param in model.roberta.parameters():
                param.requires_grad = False
                base_param += param.numel()
        elif config.model_type == "deberta":
            for param in model.deberta.parameters():
                param.requires_grad = False
                base_param += param.numel()
        elif config.model_type == "gpt2":
            for param in model.gpt2.parameters():
                param.requires_grad = False
                base_param += param.numel()
    all_param = sum(param.numel() for param in model.parameters())
    total_param = all_param - base_param
    print('***** Total param: {:0.3f}M, P-Tuning-v2 param: {} *****'.format(all_param / 1e6, total_param))
    return model
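
# Illustrative sketch (assumed argument names, mirroring the fields read
# above): loading a prefix-tuned sequence-classification model.
#
#   config = AutoConfig.from_pretrained(model_args.model_name_or_path,
#                                       revision=model_args.model_revision)
#   model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)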


def get_model_deprecated(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False):
    if model_args.prefix:
        config.hidden_dropout_prob = model_args.hidden_dropout_prob
        config.pre_seq_len = model_args.pre_seq_len
        config.prefix_projection = model_args.prefix_projection
        config.prefix_hidden_size = model_args.prefix_hidden_size

        if task_type == TaskType.TOKEN_CLASSIFICATION:
            from model.token_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
        elif task_type == TaskType.SEQUENCE_CLASSIFICATION:
            from model.sequence_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
        elif task_type == TaskType.QUESTION_ANSWERING:
            from model.question_answering import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
        elif task_type == TaskType.MULTIPLE_CHOICE:
            from model.multiple_choice import BertPrefixModel

        if config.model_type == "bert":
            model = BertPrefixModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif config.model_type == "roberta":
            model = RobertaPrefixModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif config.model_type == "deberta":
            model = DebertaPrefixModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif config.model_type == "deberta-v2":
            model = DebertaV2PrefixModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        else:
            raise NotImplementedError
    elif model_args.prompt:
        config.pre_seq_len = model_args.pre_seq_len
        from model.sequence_classification import BertPromptModel, RobertaPromptModel
        if config.model_type == "bert":
            model = BertPromptModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif config.model_type == "roberta":
            model = RobertaPromptModel.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        else:
            raise NotImplementedError
    else:
        if task_type == TaskType.TOKEN_CLASSIFICATION:
            model = AutoModelForTokenClassification.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif task_type == TaskType.SEQUENCE_CLASSIFICATION:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif task_type == TaskType.QUESTION_ANSWERING:
            model = AutoModelForQuestionAnswering.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )
        elif task_type == TaskType.MULTIPLE_CHOICE:
            model = AutoModelForMultipleChoice.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                revision=model_args.model_revision,
            )

    bert_param = 0
    if fix_bert:
        # Freeze the backbone and count its parameters in a single pass.
        if config.model_type == "bert":
            for param in model.bert.parameters():
                param.requires_grad = False
                bert_param += param.numel()
        elif config.model_type == "roberta":
            for param in model.roberta.parameters():
                param.requires_grad = False
                bert_param += param.numel()
        elif config.model_type == "deberta":
            for param in model.deberta.parameters():
                param.requires_grad = False
                bert_param += param.numel()
    all_param = sum(param.numel() for param in model.parameters())
    total_param = all_param - bert_param
    print('***** total param is {} *****'.format(total_param))
    return model