# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Exports a minimal module for ALBERT models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from albert import modeling
import tensorflow.compat.v1 as tf

flags.DEFINE_string(
    "albert_directory", None,
    "The directory containing the pre-trained ALBERT checkpoint and the "
    "albert_config.json file that specifies the model architecture.")

flags.DEFINE_string(
    "checkpoint_name", "model.ckpt-best",
    "Name of the checkpoint under albert_directory to be exported.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_string("export_path", None, "Path to the output module.")

FLAGS = flags.FLAGS
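
# Example invocation (paths are hypothetical; the module path assumes the
# google-research/albert repository layout):
#
#   python -m albert.export_checkpoints \
#     --albert_directory=/path/to/albert_base \
#     --checkpoint_name=model.ckpt-best \
#     --export_path=/tmp/albert_export/model.ckpt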


def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions over a minibatch."""
  sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  # Convert the per-example `positions` into indices into a flattened
  # [batch_size * seq_length, width] view of `sequence_tensor`, then gather.
  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor
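
# For example, with batch_size=2, seq_length=4, and positions=[[1, 3], [0, 2]],
# flat_offsets is [[0], [4]], so flat_positions is [1, 3, 4, 6] -- indices
# into the flattened [8, width] tensor.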


def get_mlm_logits(input_tensor, albert_config, mlm_positions, output_weights):
  """Computes masked-LM logits (adapted from run_pretraining.py)."""
  input_tensor = gather_indexes(input_tensor, mlm_positions)
  with tf.variable_scope("cls/predictions"):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=albert_config.embedding_size,
          activation=modeling.get_activation(albert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range))
      input_tensor = modeling.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[albert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    return logits
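
# Shape walk-through (assuming `output_weights` is the [vocab_size,
# embedding_size] embedding table, as passed from build_model below):
# gather_indexes yields [batch_size * num_masked, hidden_size]; the
# "transform" dense layer projects to embedding_size; the tied matmul then
# produces [batch_size * num_masked, vocab_size] logits.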


def get_sentence_order_logits(input_tensor, albert_config):
  """Gets logits for the sentence-order prediction (SOP) task."""
  # Simple binary classification over the pooled output. Label 0 is the
  # original segment order and label 1 is the swapped order. This weight
  # matrix is not used after pre-training.
  with tf.variable_scope("cls/seq_relationship"):
    output_weights = tf.get_variable(
        "output_weights",
        shape=[2, albert_config.hidden_size],
        initializer=modeling.create_initializer(
            albert_config.initializer_range))
    output_bias = tf.get_variable(
        "output_bias", shape=[2], initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    return logits
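
# Here `input_tensor` is the pooled [batch_size, hidden_size] output, so the
# resulting SOP logits have shape [batch_size, 2].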


def build_model(sess):
  """Builds the ALBERT graph and initializes it from the checkpoint."""
  input_ids = tf.placeholder(tf.int32, [None, None], "input_ids")
  input_mask = tf.placeholder(tf.int32, [None, None], "input_mask")
  segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids")
  mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions")

  albert_config_path = os.path.join(
      FLAGS.albert_directory, "albert_config.json")
  albert_config = modeling.AlbertConfig.from_json_file(albert_config_path)
  model = modeling.AlbertModel(
      config=albert_config,
      is_training=False,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=False)

  # Build the MLM and SOP heads so their variables exist in the graph and can
  # be restored from the checkpoint; the returned logits are not used here.
  get_mlm_logits(model.get_sequence_output(), albert_config,
                 mlm_positions, model.get_embedding_table())
  get_sentence_order_logits(model.get_pooled_output(), albert_config)

  checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name)
  tvars = tf.trainable_variables()
  (assignment_map, initialized_variable_names
  ) = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_path)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)
  tf.train.init_from_checkpoint(checkpoint_path, assignment_map)
  init = tf.global_variables_initializer()
  sess.run(init)
  return sess


def main(_):
  sess = tf.Session()
  tf.train.get_or_create_global_step()
  sess = build_model(sess)

  # Drop the LAMB optimizer's slot variables ("lamb_m"/"lamb_v", the first and
  # second moments) so the exported checkpoint contains only model weights.
  my_vars = []
  for var in tf.global_variables():
    if "lamb_v" not in var.name and "lamb_m" not in var.name:
      my_vars.append(var)

  saver = tf.train.Saver(my_vars)
  saver.save(sess, FLAGS.export_path)


if __name__ == "__main__":
  flags.mark_flag_as_required("albert_directory")
  flags.mark_flag_as_required("export_path")
  app.run(main)
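
# To sanity-check the exported checkpoint, one can list its variables
# (a sketch; the path is hypothetical):
#
#   for name, shape in tf.train.list_variables(
#       "/tmp/albert_export/model.ckpt"):
#     print(name, shape)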