# gpt2-medium-persian/src/convert_flax_to_tf.py (author: m3hrdadfi)
# Commit c5a9149: "Add pytorch, tf version" (647 bytes)
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel
from transformers import TFGPT2LMHeadModel
# Convert the GPT-2 checkpoint stored in PyTorch format under ./pt into a
# TensorFlow checkpoint under ./tf. Then sanity-check the conversion by
# running both models on the same dummy batch and comparing their logits.
# NOTE(review): despite the file name, this loads the PyTorch checkpoint
# (./pt, from_pt=True), not a Flax one.

tokenizer = AutoTokenizer.from_pretrained("../")
# Reuse EOS as the padding token (presumably the tokenizer defines none).
# NOTE(review): the tokenizer is not used anywhere else in this script.
tokenizer.pad_token = tokenizer.eos_token

model_pt = GPT2LMHeadModel.from_pretrained("./pt")
# from_pt=True builds the TF model from the PyTorch weights in ./pt.
model_tf = TFGPT2LMHeadModel.from_pretrained("./pt", from_pt=True)
model_tf.save_pretrained("./tf")

# Dummy batch: 2 sequences, each 128 copies of token id 0.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)

# PyTorch index tensors are conventionally int64. Inference only, so
# autograd is disabled and no graph is kept for the forward pass.
input_ids_pt = torch.from_numpy(input_ids).long()
with torch.no_grad():
    logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)

logits_tf = model_tf(input_ids).logits
print(logits_tf)

# Both models hold the same weights and see the same input, so their logits
# should agree up to floating-point noise. Report the largest deviation so a
# broken conversion is visible rather than buried in two printed tensors.
max_abs_diff = np.max(np.abs(np.asarray(logits_pt) - np.asarray(logits_tf)))
print(f"max |logits_pt - logits_tf| = {max_abs_diff:.3e}")