tools: update conversion script
Browse files
convert_token_dropping_bert_original_tf2_checkpoint_to_pytorch.py
CHANGED
@@ -46,8 +46,8 @@ def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pyt
|
|
46 |
full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
47 |
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
48 |
|
49 |
-
|
50 |
-
|
51 |
|
52 |
return torch.from_numpy(array)
|
53 |
|
|
|
46 |
full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
47 |
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
48 |
|
49 |
+
if "kernel" in name:
|
50 |
+
array = array.transpose()
|
51 |
|
52 |
return torch.from_numpy(array)
|
53 |
|