ameerazam08 commited on
Commit
ebc6d95
·
1 Parent(s): c6a61ed

Create export_h5.py

Browse files
Files changed (1) hide show
  1. export_h5.py +12 -0
export_h5.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TFGPT2LMHeadModel, GPT2Config, GPT2LMHeadModel
2
+
3
+ # Load your trained PyTorch model
4
+ pytorch_model_path = "trained_path"
5
+ config = GPT2Config.from_pretrained(pytorch_model_path)
6
+ pytorch_model = GPT2LMHeadModel.from_pretrained(pytorch_model_path, config=config,from_tf=True)
7
+
8
+ # Convert to TensorFlow model
9
+ tf_model = TFGPT2LMHeadModel.from_pretrained(pytorch_model_path, from_pt=True, config=config)
10
+
11
+ # Save the TensorFlow model
12
+ tf_model.save_pretrained(pytorch_model_path) # This will generate the tf_model.h5 file in the directory