# -*- coding: utf-8 -*-
# @Time : 2021/12/6 3:35 PM
# @Author : JianingWang
# @File : __init__.py
# from models.chid_mlm import BertForChidMLM
from models.multiple_choice.duma import BertDUMAForMultipleChoice, AlbertDUMAForMultipleChoice, MegatronDumaForMultipleChoice
from models.span_extraction.global_pointer import BertForEffiGlobalPointer, RobertaForEffiGlobalPointer, RoformerForEffiGlobalPointer, MegatronForEffiGlobalPointer
from transformers import AutoModelForTokenClassification, AutoModelForSequenceClassification, AutoModelForMaskedLM, AutoModelForMultipleChoice, BertTokenizer, \
    AutoModelForQuestionAnswering, AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.models.roformer import RoFormerTokenizer
from transformers.models.bert import BertTokenizerFast, BertForTokenClassification, BertTokenizer
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.models.bart.tokenization_bart import BartTokenizer
from transformers.models.t5.tokenization_t5 import T5Tokenizer
from transformers.models.plbart.tokenization_plbart import PLBartTokenizer
# from models.deberta import DebertaV2ForMultipleChoice, DebertaForMultipleChoice
# from models.fengshen.models.longformer import LongformerForMultipleChoice
from models.kg import BertForPretrainWithKG, BertForPretrainWithKGV2
from models.language_modeling.mlm import BertForMaskedLM, RobertaForMaskedLM, AlbertForMaskedLM, RoFormerForMaskedLM
# from models.sequence_classification.classification import build_cls_model
from models.multiple_choice.multiple_choice_tag import BertForTagMultipleChoice, RoFormerForTagMultipleChoice, MegatronBertForTagMultipleChoice
from models.multiple_choice.multiple_choice import MegatronBertForMultipleChoice, MegatronBertRDropForMultipleChoice
from models.semeval7 import DebertaV2ForSemEval7MultiTask
from models.sequence_matching.fusion_siamese import BertForFusionSiamese, BertForWSC
# from roformer import RoFormerForTokenClassification, RoFormerForSequenceClassification
from models.fewshot_learning.span_proto import SpanProto
from models.fewshot_learning.token_proto import TokenProto
from models.sequence_labeling.head_token_cls import (
    BertSoftmaxForSequenceLabeling, BertCrfForSequenceLabeling,
    RobertaSoftmaxForSequenceLabeling, RobertaCrfForSequenceLabeling,
    AlbertSoftmaxForSequenceLabeling, AlbertCrfForSequenceLabeling,
    MegatronBertSoftmaxForSequenceLabeling, MegatronBertCrfForSequenceLabeling,
)
from models.span_extraction.span_for_ner import BertSpanForNer, RobertaSpanForNer, AlbertSpanForNer, MegatronBertSpanForNer
from models.language_modeling.mlm import BertForMaskedLM
from models.language_modeling.kpplm import BertForWikiKGPLM, RoBertaKPPLMForProcessedWikiKGPLM, DeBertaKPPLMForProcessedWikiKGPLM
from models.language_modeling.causal_lm import GPT2ForCausalLM
from models.sequence_classification.head_cls import (
    BertForSequenceClassification, BertPrefixForSequenceClassification,
    BertPtuningForSequenceClassification, BertAdapterForSequenceClassification,
    RobertaForSequenceClassification, RobertaPrefixForSequenceClassification,
    RobertaPtuningForSequenceClassification, RobertaAdapterForSequenceClassification,
    BartForSequenceClassification, GPT2ForSequenceClassification
)
from models.sequence_classification.masked_prompt_cls import (
    PromptBertForSequenceClassification, PromptBertPtuningForSequenceClassification,
    PromptBertPrefixForSequenceClassification, PromptBertAdapterForSequenceClassification,
    PromptRobertaForSequenceClassification, PromptRobertaPtuningForSequenceClassification,
    PromptRobertaPrefixForSequenceClassification, PromptRobertaAdapterForSequenceClassification
)
from models.sequence_classification.causal_prompt_cls import PromptGPT2ForSequenceClassification
from models.code.code_classification import (
    RobertaForCodeClassification, CodeBERTForCodeClassification,
    GraphCodeBERTForCodeClassification, PLBARTForCodeClassification, CodeT5ForCodeClassification
)
from models.code.code_generation import (
    PLBARTForCodeGeneration
)
from models.reinforcement_learning.actor import CausalActor
from models.reinforcement_learning.critic import AutoModelCritic
from models.reinforcement_learning.reward_model import (
    RobertaForReward, GPT2ForReward
)
# Models for pre-training
PRETRAIN_MODEL_CLASSES = {
"mlm": {
"bert": BertForMaskedLM,
"roberta": RobertaForMaskedLM,
"albert": AlbertForMaskedLM,
"roformer": RoFormerForMaskedLM,
},
"auto_mlm": AutoModelForMaskedLM,
"causal_lm": {
"gpt2": GPT2ForCausalLM,
"bart": None,
"t5": None,
"llama": None
},
"auto_causal_lm": AutoModelForCausalLM
}
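# Usage sketch (illustrative only, not executed at import time): each registry
# in this file is keyed first by task and then, for nested entries, by model
# architecture. The checkpoint name below is a placeholder, not one pinned by this repo.
#
#   mlm_cls = PRETRAIN_MODEL_CLASSES["mlm"]["bert"]        # -> BertForMaskedLM
#   model = mlm_cls.from_pretrained("bert-base-chinese")   # any BERT-style checkpoint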
CLASSIFICATION_MODEL_CLASSES = {
"auto_cls": AutoModelForSequenceClassification, # huggingface cls
"classification": AutoModelForSequenceClassification, # huggingface cls
"head_cls": {
"bert": BertForSequenceClassification,
"roberta": RobertaForSequenceClassification,
"bart": BartForSequenceClassification,
"gpt2": GPT2ForSequenceClassification
}, # use standard fine-tuning head for cls, e.g., bert+mlp
"head_prefix_cls": {
"bert": BertPrefixForSequenceClassification,
"roberta": RobertaPrefixForSequenceClassification,
}, # use standard fine-tuning head with prefix-tuning technique for cls, e.g., bert+mlp
"head_ptuning_cls": {
"bert": BertPtuningForSequenceClassification,
"roberta": RobertaPtuningForSequenceClassification,
}, # use standard fine-tuning head with p-tuning technique for cls, e.g., bert+mlp
"head_adapter_cls": {
"bert": BertAdapterForSequenceClassification,
"roberta": RobertaAdapterForSequenceClassification,
}, # use standard fine-tuning head with adapter-tuning technique for cls, e.g., bert+mlp
"masked_prompt_cls": {
"bert": PromptBertForSequenceClassification,
"roberta": PromptRobertaForSequenceClassification,
# "deberta": PromptDebertaForSequenceClassification,
# "deberta-v2": PromptDebertav2ForSequenceClassification,
}, # use masked lm head technique for prompt-based cls, e.g., bert+mlm
"masked_prompt_prefix_cls": {
"bert": PromptBertPrefixForSequenceClassification,
"roberta": PromptRobertaPrefixForSequenceClassification,
# "deberta": PromptDebertaPrefixForSequenceClassification,
# "deberta-v2": PromptDebertav2PrefixForSequenceClassification,
}, # use masked lm head with prefix-tuning technique for prompt-based cls, e.g., bert+mlm
"masked_prompt_ptuning_cls": {
"bert": PromptBertPtuningForSequenceClassification,
"roberta": PromptRobertaPtuningForSequenceClassification,
# "deberta": PromptDebertaPtuningForSequenceClassification,
# "deberta-v2": PromptDebertav2PtuningForSequenceClassification,
}, # use masked lm head with p-tuning technique for prompt-based cls, e.g., bert+mlm
"masked_prompt_adapter_cls": {
"bert": PromptBertAdapterForSequenceClassification,
"roberta": PromptRobertaAdapterForSequenceClassification,
}, # use masked lm head with adapter-tuning technique for prompt-based cls, e.g., bert+mlm
"causal_prompt_cls": {
"gpt2": PromptGPT2ForSequenceClassification,
"bart": None,
"t5": None,
}, # use causal lm head for prompt-tuning, e.g., gpt2+lm
}
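# Usage sketch (illustrative only): the classification registry distinguishes
# standard-head fine-tuning ("head_*"), its parameter-efficient variants
# ("*_prefix_*", "*_ptuning_*", "*_adapter_*"), masked-LM prompt classifiers
# ("masked_prompt_*"), and causal-LM prompt classifiers ("causal_prompt_cls").
#
#   cls_model_cls = CLASSIFICATION_MODEL_CLASSES["masked_prompt_cls"]["bert"]
#   # -> PromptBertForSequenceClassification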
TOKEN_CLASSIFICATION_MODEL_CLASSES = {
"auto_token_cls": AutoModelForTokenClassification,
"head_softmax_token_cls": {
"bert": BertSoftmaxForSequenceLabeling,
"roberta": RobertaSoftmaxForSequenceLabeling,
"albert": AlbertSoftmaxForSequenceLabeling,
"megatron": MegatronBertSoftmaxForSequenceLabeling,
},
"head_crf_token_cls": {
"bert": BertCrfForSequenceLabeling,
"roberta": RobertaCrfForSequenceLabeling,
"albert": AlbertCrfForSequenceLabeling,
"megatron": MegatronBertCrfForSequenceLabeling,
}
}
SPAN_EXTRACTION_MODEL_CLASSES = {
"global_pointer": {
"bert": BertForEffiGlobalPointer,
"roberta": RobertaForEffiGlobalPointer,
"roformer": RoformerForEffiGlobalPointer,
"megatronbert": MegatronForEffiGlobalPointer
},
}
FEWSHOT_MODEL_CLASSES = {
"sequence_proto": None,
"span_proto": SpanProto,
"token_proto": TokenProto,
}
CODE_MODEL_CLASSES = {
"code_cls": {
"roberta": RobertaForCodeClassification,
"codebert": CodeBERTForCodeClassification,
"graphcodebert": GraphCodeBERTForCodeClassification,
"codet5": CodeT5ForCodeClassification,
"plbart": PLBARTForCodeClassification,
},
"code_generation": {
# "roberta": RobertaForCodeGeneration,
# "codebert": BertForCodeGeneration,
# "graphcodebert": BertForCodeGeneration,
# "codet5": T5ForCodeGeneration,
"plbart": PLBARTForCodeGeneration,
},
}
REINFORCEMENT_MODEL_CLASSES = {
"causal_actor": CausalActor,
"auto_critic": AutoModelCritic,
"rl_reward": {
"roberta": RobertaForReward,
"gpt2": GPT2ForReward,
"gpt-neo": None,
"opt": None,
"llama": None,
}
}
# task_type maps to the corresponding model class
OTHER_MODEL_CLASSES = {
    # sequence labeling
    "bert_span_ner": BertSpanForNer,
    "roberta_span_ner": RobertaSpanForNer,
    "albert_span_ner": AlbertSpanForNer,
    "megatronbert_span_ner": MegatronBertSpanForNer,
    # sequence matching
    "fusion_siamese": BertForFusionSiamese,
    # multiple choice
    "multi_choice": AutoModelForMultipleChoice,
    "multi_choice_megatron": MegatronBertForMultipleChoice,
    "multi_choice_megatron_rdrop": MegatronBertRDropForMultipleChoice,
    "megatron_multi_choice_tag": MegatronBertForTagMultipleChoice,
    "roformer_multi_choice_tag": RoFormerForTagMultipleChoice,
    "multi_choice_tag": BertForTagMultipleChoice,
    "duma": BertDUMAForMultipleChoice,
    "duma_albert": AlbertDUMAForMultipleChoice,
    "duma_megatron": MegatronDumaForMultipleChoice,
    # language modeling
    # "bert_mlm_acc": BertForMaskedLMWithACC,
    # "roformer_mlm_acc": RoFormerForMaskedLMWithACC,
    "bert_pretrain_kg": BertForPretrainWithKG,
    "bert_pretrain_kg_v2": BertForPretrainWithKGV2,
    "kpplm_roberta": RoBertaKPPLMForProcessedWikiKGPLM,
    "kpplm_deberta": DeBertaKPPLMForProcessedWikiKGPLM,
    # other
    "clue_wsc": BertForWSC,
    "semeval7multitask": DebertaV2ForSemEval7MultiTask,
    # "debertav2_multi_choice": DebertaV2ForMultipleChoice,
    # "deberta_multi_choice": DebertaForMultipleChoice,
    # "qa": AutoModelForQuestionAnswering,
    # "roformer_cls": RoFormerForSequenceClassification,
    # "roformer_ner": RoFormerForTokenClassification,
    # "fensheng_multi_choice": LongformerForMultipleChoice,
    # "chid_mlm": BertForChidMLM,
}
# MODEL_CLASSES = dict(list(PRETRAIN_MODEL_CLASSES.items()) + list(OTHER_MODEL_CLASSES.items()))
MODEL_CLASSES_LIST = [
    PRETRAIN_MODEL_CLASSES,
    CLASSIFICATION_MODEL_CLASSES,
    TOKEN_CLASSIFICATION_MODEL_CLASSES,
    SPAN_EXTRACTION_MODEL_CLASSES,
    FEWSHOT_MODEL_CLASSES,
    CODE_MODEL_CLASSES,
    REINFORCEMENT_MODEL_CLASSES,
    OTHER_MODEL_CLASSES,
]
MODEL_CLASSES = dict()
for model_class in MODEL_CLASSES_LIST:
    MODEL_CLASSES = dict(list(MODEL_CLASSES.items()) + list(model_class.items()))
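# Usage sketch (illustrative only): after the merge, MODEL_CLASSES contains
# every task key from the registries above. Some values are model classes
# directly, others are nested dicts keyed by architecture, so callers need
# one extra lookup for the nested case.
#
#   entry = MODEL_CLASSES["head_cls"]
#   model_cls = entry["roberta"] if isinstance(entry, dict) else entry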
# model_type maps to the corresponding tokenizer
TOKENIZER_CLASSES = {
    # for natural language processing
    "auto": AutoTokenizer,
    "bert": BertTokenizerFast,
    "roberta": RobertaTokenizer,
    "wobert": RoFormerTokenizer,
    "roformer": RoFormerTokenizer,
    "bigbird": BertTokenizerFast,
    "erlangshen": BertTokenizerFast,
    "deberta": BertTokenizer,
    "roformer_v2": BertTokenizerFast,
    "gpt2": GPT2Tokenizer,
    "megatronbert": BertTokenizerFast,
    "bart": BartTokenizer,
    "t5": T5Tokenizer,
    # for programming language processing
    "codebert": RobertaTokenizer,
    "graphcodebert": RobertaTokenizer,
    "codet5": RobertaTokenizer,
    "plbart": PLBartTokenizer
}
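# Usage sketch (illustrative only): model_type picks the tokenizer class, which
# is then instantiated from a checkpoint name or path; pairing it with the
# matching entry in MODEL_CLASSES gives a tokenizer/model couple. The checkpoint
# name below is a placeholder, not one pinned by this repo.
#
#   tokenizer_cls = TOKENIZER_CLASSES["roberta"]              # -> RobertaTokenizer
#   tokenizer = tokenizer_cls.from_pretrained("roberta-base")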