Upload TFBilma
Files changed:
- config.json             +1 -0
- configuration_bilma.py  +6 -1
- modeling_bilma.py       +16 -6
- tf_model.h5             +1 -1
config.json CHANGED
@@ -9,6 +9,7 @@
   },
   "hidden_dropout_prob": 0.1,
   "hidden_size": 512,
+  "include_head": null,
   "include_top": true,
   "model_type": "bilma",
   "num_attention_heads": 4,
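A quick sketch of what the new key corresponds to on the Python side (hypothetical local usage; it assumes configuration_bilma.py from this repo is on the import path):

    from configuration_bilma import BilmaConfig

    # Default configuration: keep the LM top, add no extra head.
    cfg = BilmaConfig()
    assert cfg.include_top is True and cfg.include_head is None

    # to_dict()/save_pretrained() serialise the new attribute, which is
    # where the `"include_head": null` entry in config.json comes from.
    print(cfg.to_dict().get("include_head"))   # None -> written as null in JSON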
configuration_bilma.py CHANGED
@@ -6,7 +6,8 @@ class BilmaConfig(PretrainedConfig):
     def __init__(
         self,
         weights="MX",
-        include_top=True,
+        include_top = True,
+        include_head = None,
         num_attention_heads: int = 4,
         num_hidden_layers: int = 2,
         seq_max_length: int = 280,
@@ -18,9 +19,12 @@ class BilmaConfig(PretrainedConfig):
         countries = ["MX"]
         if weights not in countries:
             raise ValueError(f"`weights` must be one of {countries}, got {weights}.")
+        if include_head is not None and include_top == True:
+            raise ValueError(f"To include a head, 'include_top' must be False")
         if weights is not None:
             self.weights = weights
             self.include_top = include_top
+            self.include_head = include_head
             self.num_attention_heads = 4
             self.num_hidden_layers = 2
             self.seq_max_length = 280
@@ -32,6 +36,7 @@ class BilmaConfig(PretrainedConfig):
 
         self.weights = weights
         self.include_top = include_top
+        self.include_head = include_head
         self.num_attention_heads = num_attention_heads
         self.num_hidden_layers = num_hidden_layers
         self.seq_max_length = seq_max_length
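Usage sketch for the new constructor arguments (the values are hypothetical; judging from how modeling_bilma.py consumes it, include_head is a list of Dense-layer widths whose last entry is the size of the final projection):

    from configuration_bilma import BilmaConfig

    # Raw encoder output: no LM top and no extra head.
    cfg_encoder = BilmaConfig(include_top=False)

    # Encoder plus a small task head: Dense(64, relu) -> Dense(32, relu) -> Dense(3).
    cfg_classifier = BilmaConfig(include_top=False, include_head=[64, 32, 3])

    # Asking for the extra head while keeping the LM top is rejected.
    try:
        BilmaConfig(include_top=True, include_head=[3])
    except ValueError as err:
        print(err)   # To include a head, 'include_top' must be False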
modeling_bilma.py CHANGED
@@ -9,7 +9,7 @@ from typing import Dict
 import re
 import unicodedata
 
-from
+from configuration_bilma import BilmaConfig
 
 # copied from preprocessing.py
 BLANK = ' '
@@ -47,7 +47,8 @@ class TFBilma(TFPreTrainedModel):
                                 ff_dim=config.hidden_size,
                                 vocab_size=config.vocab_size,
                                 rate=config.hidden_dropout_prob,
-                                include_top = config.include_top
+                                include_top = config.include_top,
+                                include_head = config.include_head)
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -74,7 +75,10 @@ class TFBilma(TFPreTrainedModel):
         if self.include_top:
             output = {"logits":self.model(ins)}
         else:
-
+            if self.include_head is None:
+                output = {"last_hidden_state":self.model(ins)}
+            else:
+                output = {"logits":self.model(ins)}
         return output
 
 # copied from bilma_model.py
@@ -105,7 +109,7 @@ def accuracy_function(ignore_id=0):
         return tf.math.divide_no_nan(tf.reduce_sum(accuracies), tf.reduce_sum(mask))
     return acc_mlm
 
-def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, vocab_size=9739, rate=0.1, include_top=True):
+def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, vocab_size=9739, rate=0.1, include_top=True, include_head=None):
     capt_inputs_ids = Input(shape=(max_length, ), name='input_ids')
     capt_embedding = Embedding(vocab_size, embed_dim, mask_zero=False, name="bilma/embedding")
     capt_inputs = capt_embedding(capt_inputs_ids)
@@ -115,9 +119,15 @@ def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, voca
     if include_top:
         fin_output = Dense(vocab_size, use_bias=True, name="bilma/dense_final")(enc_output)
     else:
-
+        if include_head is None:
+            fin_output = enc_output
+        else:
+            x = enc_output
+            for i, m in enumerate(include_head[:-1]):
+                x = Dense(m, use_bias=True, activation="relu", name=f"bilma/dense_ex_{i}")(x)
+            fin_output = [Dense(include_head[-1], use_bias=True, name=f"bilma/dense_ex_final")(x), enc_output]
 
-    caption_model = Model(inputs=capt_inputs_ids, outputs=
+    caption_model = Model(inputs=capt_inputs_ids, outputs=fin_output, name="bilma_model")
     return caption_model
 
 def load(model_file):
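A hedged end-to-end sketch of how the options above change what TFBilma and the underlying bilma() builder return (it assumes the rest of modeling_bilma.py builds the encoder with these config values; pretrained weights would still need to be loaded separately, e.g. via from_pretrained):

    from configuration_bilma import BilmaConfig
    from modeling_bilma import TFBilma, bilma

    # include_top=True (default): call() returns {"logits": ...} from the vocabulary head.
    lm_model = TFBilma(BilmaConfig())

    # include_top=False, include_head=None: call() returns {"last_hidden_state": ...}.
    encoder_model = TFBilma(BilmaConfig(include_top=False))

    # include_top=False, include_head=[64, 3]: the Keras graph gains a
    # Dense(64, relu) -> Dense(3) head, and call() returns {"logits": ...}
    # where the value is [head_output, enc_output].
    classifier_model = TFBilma(BilmaConfig(include_top=False, include_head=[64, 3]))

    # The same head wiring is available directly on the raw Keras builder:
    keras_classifier = bilma(include_top=False, include_head=[64, 3])
    keras_classifier.summary()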
tf_model.h5 CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f83fdad7da418dac337cc4df40cb630f3145ff66b48188148e899214539e2db5
 size 156875820