sunit333's picture
Upload 63 files
d08dd00 verified
raw
history blame
No virus
6.56 kB
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Exports a minimal TF-Hub module for ALBERT models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from albert import modeling
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
# Command-line flags selecting the ALBERT model to export and how to export it.

# Directory holding every input file this script reads: the model config
# (albert_config.json), the checkpoint and the two vocabulary files.
flags.DEFINE_string(
    "albert_directory", None,
    "Directory containing the pre-trained ALBERT model files: "
    "albert_config.json (which specifies the model architecture), the "
    "checkpoint named by --checkpoint_name, and the vocabulary files "
    "30k-clean.model and 30k-clean.vocab.")

flags.DEFINE_string(
    "checkpoint_name", "model.ckpt-best",
    "Name of the checkpoint under albert_directory to be exported.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_bool(
    "use_einsum", True,
    "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must "
    "be set to False for TFLite compatibility.")

flags.DEFINE_string("export_path", None, "Path to the output TF-Hub module.")

# Parsed flag values; populated once `app.run` has processed the command line.
FLAGS = flags.FLAGS
def gather_indexes(sequence_tensor, positions):
    """Picks per-example vectors out of a rank-3 tensor by position.

    The tensor is flattened over its first two dimensions, and every position
    is shifted by `example_index * seq_length` so that it addresses the right
    row of the flattened tensor.

    Args:
      sequence_tensor: Rank-3 tensor whose dimensions are treated as
        [batch_size, seq_length, width].
      positions: Integer tensor of indices along the second dimension of
        `sequence_tensor`. Presumably shaped [batch_size, num_positions] so
        that it broadcasts against the per-example offsets; confirm against
        callers.

    Returns:
      A rank-2 tensor holding one `width`-sized row per entry of `positions`.
    """
    dims = modeling.get_shape_list(sequence_tensor, expected_rank=3)
    batch_size, seq_length, width = dims[0], dims[1], dims[2]

    # Row index, in the flattened tensor, of each example's first vector.
    example_starts = tf.range(0, batch_size, dtype=tf.int32) * seq_length
    offsets = tf.reshape(example_starts, [-1, 1])

    flat_indices = tf.reshape(positions + offsets, [-1])
    flattened = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
    return tf.gather(flattened, flat_indices)
def get_mlm_logits(model, albert_config, mlm_positions):
    """Builds the masked-LM logits head (taken from run_pretraining.py).

    Args:
      model: Model object providing `get_sequence_output()` and
        `get_embedding_table()`.
      albert_config: Config object providing `embedding_size`, `hidden_act`,
        `initializer_range` and `vocab_size`.
      mlm_positions: Integer tensor of the token positions to score.

    Returns:
      A tensor of vocabulary logits, one row per gathered position.
    """
    hidden = gather_indexes(model.get_sequence_output(), mlm_positions)

    with tf.variable_scope("cls/predictions"):
        # An extra non-linear projection is applied ahead of the output layer.
        # This matrix is not used after pre-training.
        with tf.variable_scope("transform"):
            activation_fn = modeling.get_activation(albert_config.hidden_act)
            kernel_init = modeling.create_initializer(
                albert_config.initializer_range)
            hidden = tf.layers.dense(
                hidden,
                units=albert_config.embedding_size,
                activation=activation_fn,
                kernel_initializer=kernel_init)
            hidden = modeling.layer_norm(hidden)

        # The output projection reuses the input embedding table; only the
        # per-token bias is specific to the output layer.
        output_bias = tf.get_variable(
            "output_bias",
            shape=[albert_config.vocab_size],
            initializer=tf.zeros_initializer())
        embedding_table = model.get_embedding_table()
        scores = tf.matmul(hidden, embedding_table, transpose_b=True)
        scores = tf.nn.bias_add(scores, output_bias)

    return scores
def module_fn(is_training):
    """Builds the ALBERT graph and registers the TF-Hub signatures.

    Three signatures are added: "tokens" (sequence and pooled outputs),
    "mlm" (the same outputs plus masked-LM logits) and "tokenization_info"
    (vocabulary path and lower-casing setting).

    Args:
      is_training: Forwarded to `modeling.AlbertModel` as its `is_training`
        argument.
    """
    # Graph inputs; both dimensions of every placeholder are left unspecified.
    input_ids = tf.placeholder(tf.int32, [None, None], "input_ids")
    input_mask = tf.placeholder(tf.int32, [None, None], "input_mask")
    segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids")
    mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions")

    model_dir = FLAGS.albert_directory
    albert_config_path = os.path.join(model_dir, "albert_config.json")
    albert_config = modeling.AlbertConfig.from_json_file(albert_config_path)

    model = modeling.AlbertModel(
        config=albert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=False,
        use_einsum=FLAGS.use_einsum)

    mlm_logits = get_mlm_logits(model, albert_config, mlm_positions)

    vocab_model_path = os.path.join(model_dir, "30k-clean.model")
    vocab_file_path = os.path.join(model_dir, "30k-clean.vocab")

    # String constants holding the asset paths.
    config_file = tf.constant(
        value=albert_config_path, dtype=tf.string, name="config_file")
    vocab_model = tf.constant(
        value=vocab_model_path, dtype=tf.string, name="vocab_model")
    # The .vocab file is included for visualization purposes only.
    vocab_file = tf.constant(
        value=vocab_file_path, dtype=tf.string, name="vocab_file")

    # Registering the path tensors in the ASSET_FILEPATHS collection lets
    # TF-Hub rewrite them so that the exported assets are portable.
    for asset_tensor in (config_file, vocab_model, vocab_file):
        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_tensor)

    hub.add_signature(
        name="tokens",
        inputs={
            "input_ids": input_ids,
            "input_mask": input_mask,
            "segment_ids": segment_ids,
        },
        outputs={
            "sequence_output": model.get_sequence_output(),
            "pooled_output": model.get_pooled_output(),
        })

    hub.add_signature(
        name="mlm",
        inputs={
            "input_ids": input_ids,
            "input_mask": input_mask,
            "segment_ids": segment_ids,
            "mlm_positions": mlm_positions,
        },
        outputs={
            "sequence_output": model.get_sequence_output(),
            "pooled_output": model.get_pooled_output(),
            "mlm_logits": mlm_logits,
        })

    # The `vocab_file` output carries the 30k-clean.model path tensor.
    hub.add_signature(
        name="tokenization_info",
        inputs={},
        outputs={
            "vocab_file": vocab_model,
            "do_lower_case": tf.constant(FLAGS.do_lower_case),
        })
def main(_):
    """Creates the module spec for both graph variants and exports it.

    Two variants are registered: one tagged {"train"} built with
    `is_training=True`, and an untagged one built with `is_training=False`.
    Weights are read from the checkpoint `--checkpoint_name` under
    `--albert_directory`, and the module is written to `--export_path`.
    """
    tags_and_args = [
        ({"train"} if is_training else set(), {"is_training": is_training})
        for is_training in (True, False)
    ]
    spec = hub.create_module_spec(module_fn, tags_and_args=tags_and_args)

    checkpoint_path = os.path.join(
        FLAGS.albert_directory, FLAGS.checkpoint_name)
    tf.logging.info("Using checkpoint {}".format(checkpoint_path))
    spec.export(FLAGS.export_path, checkpoint_path=checkpoint_path)
if __name__ == "__main__":
    # Both flags default to None, so they must be given on the command line.
    for required_flag in ("albert_directory", "export_path"):
        flags.mark_flag_as_required(required_flag)
    app.run(main)