|
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast |
|
|
|
|
|
def get_tokenlizer(text_encoder_type):
    """Load the tokenizer for a text encoder name or a config carrying one.

    Args:
        text_encoder_type: Either the encoder identifier as a string, or a
            config-like object that exposes it as a ``text_encoder_type``
            attribute or under a ``"text_encoder_type"`` key reachable via a
            ``.get`` method (for example a dict).

    Returns:
        The object returned by ``AutoTokenizer.from_pretrained`` for the
        resolved identifier.

    Raises:
        ValueError: If the argument is not a string and no truthy
            ``text_encoder_type`` can be resolved from it.
    """
    if not isinstance(text_encoder_type, str):
        config = text_encoder_type
        if hasattr(config, "text_encoder_type"):
            text_encoder_type = config.text_encoder_type
        else:
            # Only mapping-like objects have ``.get``; guard the lookup so
            # unsupported inputs (int, list, None, ...) reach the ValueError
            # below instead of failing with an AttributeError.
            getter = getattr(config, "get", None)
            resolved = getter("text_encoder_type", False) if callable(getter) else False
            if not resolved:
                raise ValueError(
                    "Unknown type of text_encoder_type: {}".format(type(config))
                )
            text_encoder_type = resolved
    print("final text_encoder_type: {}".format(text_encoder_type))

    tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
    return tokenizer
|
|
|
|
|
def get_pretrained_language_model(text_encoder_type):
    """Instantiate the pretrained language model for a supported encoder name.

    Args:
        text_encoder_type: Encoder identifier; only ``"bert-base-uncased"``
            and ``"roberta-base"`` are accepted.

    Returns:
        The result of ``BertModel.from_pretrained`` or
        ``RobertaModel.from_pretrained`` for the given identifier.

    Raises:
        ValueError: If the identifier is not one of the supported names.
    """
    wants_bert = text_encoder_type == "bert-base-uncased"
    wants_roberta = text_encoder_type == "roberta-base"

    # Reject anything unsupported up front so the happy path reads linearly.
    if not (wants_bert or wants_roberta):
        raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))

    # The conditional expression only evaluates the class that is selected.
    model_cls = BertModel if wants_bert else RobertaModel
    return model_cls.from_pretrained(text_encoder_type)
|
|