# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Exports a minimal module for ALBERT models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
from albert import modeling
import tensorflow.compat.v1 as tf

flags.DEFINE_string(
    "albert_directory", None,
    "The directory containing the pre-trained ALBERT model, including the "
    "albert_config.json file that specifies the model architecture.")

flags.DEFINE_string(
    "checkpoint_name", "model.ckpt-best",
    "Name of the checkpoint under albert_directory to be exported.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_string("export_path", None, "Path to the output module.")

FLAGS = flags.FLAGS


def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specified positions over a minibatch."""
  sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor


def get_mlm_logits(input_tensor, albert_config, mlm_positions, output_weights):
  """Builds the masked-LM output logits (adapted from run_pretraining.py)."""
  input_tensor = gather_indexes(input_tensor, mlm_positions)
  with tf.variable_scope("cls/predictions"):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=albert_config.embedding_size,
          activation=modeling.get_activation(albert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range))
      input_tensor = modeling.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[albert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
  return logits


def get_sentence_order_logits(input_tensor, albert_config):
  """Gets the logits for sentence-order prediction (adapted from run_pretraining.py)."""
  # Simple binary classification. Note that label 0 corresponds to the
  # original segment order and label 1 to the swapped order. This weight
  # matrix is not used after pre-training.
  with tf.variable_scope("cls/seq_relationship"):
    output_weights = tf.get_variable(
        "output_weights",
        shape=[2, albert_config.hidden_size],
        initializer=modeling.create_initializer(
            albert_config.initializer_range))
    output_bias = tf.get_variable(
        "output_bias", shape=[2], initializer=tf.zeros_initializer())

    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    return logits


def build_model(sess):
  """Builds the inference graph and initializes it from the checkpoint."""
  input_ids = tf.placeholder(tf.int32, [None, None], "input_ids")
  input_mask = tf.placeholder(tf.int32, [None, None], "input_mask")
  segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids")
  mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions")

  albert_config_path = os.path.join(
      FLAGS.albert_directory, "albert_config.json")
  albert_config = modeling.AlbertConfig.from_json_file(albert_config_path)
  model = modeling.AlbertModel(
      config=albert_config,
      is_training=False,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=False)

  # Attach the pre-training heads so that their variables are part of the
  # exported graph.
  get_mlm_logits(model.get_sequence_output(), albert_config,
                 mlm_positions, model.get_embedding_table())
  get_sentence_order_logits(model.get_pooled_output(), albert_config)

  checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name)
  tvars = tf.trainable_variables()
  (assignment_map, initialized_variable_names
  ) = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_path)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  tf.train.init_from_checkpoint(checkpoint_path, assignment_map)
  init = tf.global_variables_initializer()
  sess.run(init)
  return sess


def main(_):
  sess = tf.Session()
  tf.train.get_or_create_global_step()
  sess = build_model(sess)

  # Keep every global variable except the LAMB optimizer slot variables,
  # which are not needed for inference or fine-tuning.
  my_vars = []
  for var in tf.global_variables():
    if "lamb_v" not in var.name and "lamb_m" not in var.name:
      my_vars.append(var)

  saver = tf.train.Saver(my_vars)
  saver.save(sess, FLAGS.export_path)


if __name__ == "__main__":
  flags.mark_flag_as_required("albert_directory")
  flags.mark_flag_as_required("export_path")
  app.run(main)
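

# Example invocation (a sketch only: the paths are illustrative and the module
# name assumes this file lives at albert/export_checkpoints.py, which is not
# stated in the script itself):
#
#   python -m albert.export_checkpoints \
#     --albert_directory=/path/to/albert_base \
#     --checkpoint_name=model.ckpt-best \
#     --export_path=/path/to/output/albert_base_export
#
# Because the LAMB optimizer slot variables are stripped, the checkpoint
# written to --export_path can be restored into an identically constructed
# graph with a plain tf.train.Saver.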