aapot
committed on
Commit
·
2670ecc
1
Parent(s):
d5fd16e
Add pytorch model
Browse files
- config.json +2 -0
- flax_model.msgpack +2 -2
- flax_model_to_pytorch.py +20 -0
- pytorch_model.bin +3 -0
config.json
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
{
|
|
|
2 |
"architectures": [
|
3 |
"RobertaForMaskedLM"
|
4 |
],
|
@@ -19,6 +20,7 @@
|
|
19 |
"num_hidden_layers": 24,
|
20 |
"pad_token_id": 1,
|
21 |
"position_embedding_type": "absolute",
|
|
|
22 |
"transformers_version": "4.10.0.dev0",
|
23 |
"type_vocab_size": 1,
|
24 |
"use_cache": true,
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "./",
|
3 |
"architectures": [
|
4 |
"RobertaForMaskedLM"
|
5 |
],
|
|
|
20 |
"num_hidden_layers": 24,
|
21 |
"pad_token_id": 1,
|
22 |
"position_embedding_type": "absolute",
|
23 |
+
"torch_dtype": "float32",
|
24 |
"transformers_version": "4.10.0.dev0",
|
25 |
"type_vocab_size": 1,
|
26 |
"use_cache": true,
|
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7cf30f16ac72e048be2b0ad47ce76fdf2efcb13b5346dcf8a7d20d633848f7ac
|
3 |
+
size 1421662309
|
flax_model_to_pytorch.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM, AutoTokenizer
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
# Select the CPU as JAX's platform for this script.
jax.config.update('jax_platform_name', 'cpu')

# Directory that holds the Flax checkpoint; the converted weights are written
# back to the same place.
MODEL_PATH = "./"

# Load the Flax masked-LM checkpoint that is going to be converted.
model = FlaxRobertaForMaskedLM.from_pretrained(MODEL_PATH)
|
9 |
+
def to_f32(t):
    """Return the pytree ``t`` with every bfloat16 leaf cast to float32.

    Leaves of any other dtype are passed through unchanged, and the tree
    structure is preserved.

    Args:
        t: A pytree whose leaves expose ``dtype`` and ``astype``.

    Returns:
        A pytree with the same structure as ``t``.
    """
    def _upcast(leaf):
        # Only bfloat16 is converted; everything else keeps its dtype.
        return leaf.astype(jnp.float32) if leaf.dtype == jnp.bfloat16 else leaf

    # ``jax.tree_map`` was a deprecated top-level alias that newer JAX
    # releases removed; ``jax.tree_util.tree_map`` works on old and new
    # versions alike.
    return jax.tree_util.tree_map(_upcast, t)
|
11 |
+
# Pass the parameters through to_f32 and write the Flax checkpoint back to
# MODEL_PATH before it is re-read from disk by the PyTorch loader below.
model.params = to_f32(model.params)
model.save_pretrained(MODEL_PATH)

# Load the Flax weights just saved into the PyTorch architecture.
pt_model = RobertaForMaskedLM.from_pretrained(MODEL_PATH, from_flax=True).to('cpu')

# Dummy batch for a sanity check: 2 rows of 128 token ids, all 0.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
# Token ids are given to PyTorch as int64, the index dtype that embedding
# lookups accept on every torch version.
input_ids_pt = torch.tensor(input_ids, dtype=torch.long)

# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    logits_pt = pt_model(input_ids_pt).logits
print(logits_pt)

logits_fx = model(input_ids).logits
print(logits_fx)

# The printed tensors are truncated, so also report the largest element-wise
# gap between the two frameworks' outputs.
max_abs_diff = float(np.max(np.abs(np.asarray(logits_fx) - logits_pt.numpy())))
print(f"max abs difference between Flax and PyTorch logits: {max_abs_diff}")

# Write the PyTorch weights to MODEL_PATH.
pt_model.save_pretrained(MODEL_PATH)
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:190578a5782162752e1c2aaf1b04ef8b3db300a245a57746ffaea3f22db44963
|
3 |
+
size 1421780139
|