File size: 306 Bytes
49a91ba |
1 2 3 4 5 6 |
"""Convert the Flax GPT-2 checkpoint in the current directory to PyTorch.

Reads the model configuration from ``./``, builds a PyTorch ``GPT2Model``
from it, copies the weights stored in ``./flax_model.msgpack`` into that
model, and saves the result back into ``./``.
"""
from transformers import GPT2Config, GPT2Model
from transformers.modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model

# The public class names carry no underscore (GPT2Config / GPT2Model);
# `GPT2_Config` / `GPT2_Model` are not exported by transformers.
config = GPT2Config.from_pretrained("./")

# Instantiate the PyTorch model from the configuration alone; its weights are
# overwritten by the Flax checkpoint loaded on the next line.
model = GPT2Model(config)
load_flax_checkpoint_in_pytorch_model(model, "./flax_model.msgpack")

# Write the PyTorch weights next to the existing Flax files.
model.save_pretrained("./")