t5-v1.1-base-dutch-cnn-test / flax_to_pytorch.py
yhavinga: Update flax to pytorch script (commit 332e951)
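# Requires transformers with PyTorch, Flax and TensorFlow installed.
from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration

# Load the Flax checkpoint in the current directory and save it as PyTorch weights.
pt_model = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
pt_model.save_pretrained(".")

# Load the PyTorch weights just written and save them as a TensorFlow checkpoint.
tf_model = TFT5ForConditionalGeneration.from_pretrained(".", from_pt=True)
tf_model.save_pretrained(".")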
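After the script runs, the model directory should hold a weight file for each of the three frameworks. The sketch below is a hypothetical, stdlib-only check for that. It assumes the default unsharded file names used by transformers (flax_model.msgpack, pytorch_model.bin or model.safetensors depending on the library version, and tf_model.h5). Large models may instead be split into shards with an index file, which this check does not handle. It also only looks at file names and does not verify that the weights load or agree.

```python
from pathlib import Path

# Assumed default weight file names written by transformers' save_pretrained
# (unsharded checkpoints only; sharded models use *.index.json plus shards).
EXPECTED_WEIGHTS = {
    "flax": ("flax_model.msgpack",),
    # Newer transformers versions save PyTorch weights as safetensors by default.
    "pytorch": ("pytorch_model.bin", "model.safetensors"),
    "tensorflow": ("tf_model.h5",),
}


def missing_frameworks(model_dir: str = ".") -> list[str]:
    """Return the frameworks for which no expected weight file exists in model_dir.

    Hypothetical helper for illustration: it checks file names only,
    not whether the weights actually load or produce matching outputs.
    """
    root = Path(model_dir)
    return [
        framework
        for framework, names in EXPECTED_WEIGHTS.items()
        if not any((root / name).is_file() for name in names)
    ]


if __name__ == "__main__":
    missing = missing_frameworks(".")
    if missing:
        print("Missing weights for:", ", ".join(missing))
    else:
        print("Found Flax, PyTorch and TensorFlow weights.")
```

Run from the same directory as the conversion script, an empty result means a weight file for every framework is present. Comparing model outputs across frameworks would be a stronger check, but it needs all three deep learning libraries installed.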