guillermoruiz committed (verified)
Commit b3e7614 · Parent(s): ceef720

Upload TFBilma

Files changed (4):
  1. config.json (+1 -0)
  2. configuration_bilma.py (+6 -1)
  3. modeling_bilma.py (+16 -6)
  4. tf_model.h5 (+1 -1)
config.json CHANGED
@@ -9,6 +9,7 @@
   },
   "hidden_dropout_prob": 0.1,
   "hidden_size": 512,
+  "include_head": null,
   "include_top": true,
   "model_type": "bilma",
   "num_attention_heads": 4,
configuration_bilma.py CHANGED
@@ -6,7 +6,8 @@ class BilmaConfig(PretrainedConfig):
     def __init__(
         self,
         weights="MX",
-        include_top=True,
+        include_top = True,
+        include_head = None,
         num_attention_heads: int = 4,
         num_hidden_layers: int = 2,
         seq_max_length: int = 280,
@@ -18,9 +19,12 @@ class BilmaConfig(PretrainedConfig):
         countries = ["MX"]
         if weights not in countries:
             raise ValueError(f"`weights` must be one of {countries}, got {weights}.")
+        if include_head is not None and include_top == True:
+            raise ValueError(f"To include a head, 'include_top' must be False")
         if weights is not None:
             self.weights = weights
             self.include_top = include_top
+            self.include_head = include_head
             self.num_attention_heads = 4
             self.num_hidden_layers = 2
             self.seq_max_length = 280
@@ -32,6 +36,7 @@ class BilmaConfig(PretrainedConfig):
 
         self.weights = weights
         self.include_top = include_top
+        self.include_head = include_head
         self.num_attention_heads = num_attention_heads
         self.num_hidden_layers = num_hidden_layers
         self.seq_max_length = seq_max_length
modeling_bilma.py CHANGED
@@ -9,7 +9,7 @@ from typing import Dict
 import re
 import unicodedata
 
-from .configuration_bilma import BilmaConfig
+from configuration_bilma import BilmaConfig
 
 # copied from preprocessing.py
 BLANK = ' '
@@ -47,7 +47,8 @@ class TFBilma(TFPreTrainedModel):
                           ff_dim=config.hidden_size,
                           vocab_size=config.vocab_size,
                           rate=config.hidden_dropout_prob,
-                          include_top = config.include_top)
+                          include_top = config.include_top,
+                          include_head = config.include_head)
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -74,7 +75,10 @@
         if self.include_top:
             output = {"logits":self.model(ins)}
         else:
-            output = {"last_hidden_state":self.model(ins)}
+            if self.include_head is None:
+                output = {"last_hidden_state":self.model(ins)}
+            else:
+                output = {"logits":self.model(ins)}
         return output
 
 # copied from bilma_model.py
@@ -105,7 +109,7 @@ def accuracy_function(ignore_id=0):
         return tf.math.divide_no_nan(tf.reduce_sum(accuracies), tf.reduce_sum(mask))
     return acc_mlm
 
-def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, vocab_size=9739, rate=0.1, include_top=True):
+def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, vocab_size=9739, rate=0.1, include_top=True, include_head=None):
     capt_inputs_ids = Input(shape=(max_length, ), name='input_ids')
     capt_embedding = Embedding(vocab_size, embed_dim, mask_zero=False, name="bilma/embedding")
     capt_inputs = capt_embedding(capt_inputs_ids)
@@ -115,9 +119,15 @@ def bilma(num_enc=6, embed_dim=300, max_length=50, num_heads=6, ff_dim=512, voca
     if include_top:
         fin_output = Dense(vocab_size, use_bias=True, name="bilma/dense_final")(enc_output)
     else:
-        fin_output = enc_output
+        if self.include_head is None:
+            fin_output = enc_output
+        else:
+            x = enc_output
+            for i, m in enumerate(self.include_head[:-1]):
+                x = Dense(m, use_bias=True, activation="relu", name=f"bilma/dense_ex_{i}")(x)
+            fin_output = [Dense(self.include_head[-1], use_bias=True, name=f"bilma/dense_ex_final")(x), enc_output]
 
-    caption_model = Model(inputs=capt_inputs_ids, outputs=[fin_output], name="bilma_model")
+    caption_model = Model(inputs=capt_inputs_ids, outputs=fin_output, name="bilma_model")
     return caption_model
 
 def load(model_file):
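For illustration, a standalone Keras sketch of the head that bilma() attaches when include_head is given; it uses the include_head argument directly and made-up shapes (sequence length 280, hidden size 512, head sizes [64, 3]). Every entry but the last becomes a ReLU Dense layer, the last entry is the final Dense, and the encoder output is returned alongside the head output:

    from tensorflow.keras.layers import Dense, Input
    from tensorflow.keras.models import Model

    include_head = [64, 3]                     # hypothetical head sizes
    enc_output = Input(shape=(280, 512))       # stands in for the transformer encoder output

    x = enc_output
    for i, m in enumerate(include_head[:-1]):  # hidden layers of the head
        x = Dense(m, use_bias=True, activation="relu", name=f"bilma/dense_ex_{i}")(x)
    head = Dense(include_head[-1], use_bias=True, name="bilma/dense_ex_final")(x)

    # As in the diff, the model emits the head output together with the encoder states.
    sketch = Model(inputs=enc_output, outputs=[head, enc_output], name="bilma_head_sketch")
    sketch.summary()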
tf_model.h5 CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bbfa589e471d9015d5ca64d2d212afa28da612a2ff8f2d93560fca1b03167afa
+oid sha256:f83fdad7da418dac337cc4df40cb630f3145ff66b48188148e899214539e2db5
 size 156875820