# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adversarial training to learn trivial encryption functions, | |
from the paper "Learning to Protect Communications with | |
Adversarial Neural Cryptography", Abadi & Andersen, 2016. | |
https://arxiv.org/abs/1610.06918 | |
This program creates and trains three neural networks, | |
termed Alice, Bob, and Eve. Alice takes inputs | |
in_m (message), in_k (key) and outputs 'ciphertext'. | |
Bob takes inputs in_k, ciphertext and tries to reconstruct | |
the message. | |
Eve is an adversarial network that takes input ciphertext | |
and also tries to reconstruct the message. | |
The main function attempts to train these networks and then | |
evaluates them, all on random plaintext and key values. | |
""" | |

# Python 2/3 compatibility imports.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import signal
import sys

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
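
# Note: this script targets TensorFlow 1.x.  It relies on tf.contrib,
# which was removed in TensorFlow 2.x; the tf.compat.v1 calls keep it
# working on late 1.x releases, where both APIs are available.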

flags = tf.compat.v1.app.flags

flags.DEFINE_float('learning_rate', 0.0008, 'Constant learning rate')
flags.DEFINE_integer('batch_size', 4096, 'Batch size')

FLAGS = flags.FLAGS

# Input and output configuration.
TEXT_SIZE = 16
KEY_SIZE = 16

# Training parameters.
ITERS_PER_ACTOR = 1
EVE_MULTIPLIER = 2  # Train Eve 2x for every step of Alice/Bob.

# Train until either max loops or Alice/Bob "good enough":
MAX_TRAINING_LOOPS = 850000
BOB_LOSS_THRESH = 0.02  # Exit when Bob loss < 0.02 and Eve > 7.7 bits.
EVE_LOSS_THRESH = 7.7

# Logging and evaluation.
PRINT_EVERY = 200  # In training, log every 200 steps.
EVE_EXTRA_ROUNDS = 2000  # At the end, train Eve a bit more.
RETRAIN_EVE_ITERS = 10000  # Retrain Eve up to ITERS*LOOPS times.
RETRAIN_EVE_LOOPS = 25  # With an evaluation each loop.
NUMBER_OF_EVE_RESETS = 5  # And do this up to 5 times with a fresh Eve.

# Use EVAL_BATCHES samples each time we check accuracy.
EVAL_BATCHES = 1


def batch_of_random_bools(batch_size, n):
  """Return a batch of random "boolean" numbers.

  Args:
    batch_size: Batch size dimension of returned tensor.
    n: number of entries per batch.

  Returns:
    A [batch_size, n] tensor of "boolean" numbers, where each number is
    represented as -1 or 1.
  """
  as_int = tf.random.uniform(
      [batch_size, n], minval=0, maxval=2, dtype=tf.int32)
  expanded_range = (as_int * 2) - 1
  return tf.cast(expanded_range, tf.float32)
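

# For example (illustrative only; actual values are random):
#   batch_of_random_bools(2, 4) might evaluate to
#   [[ 1., -1., -1.,  1.],
#    [-1.,  1.,  1.,  1.]]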


class AdversarialCrypto(object):
  """Primary model implementation class for Adversarial Neural Crypto.

  This class contains the code for the model itself,
  and when created, plumbs the pathways from Alice to Bob and
  Eve, creates the optimizers and loss functions, etc.

  Attributes:
    eve_loss: Eve's loss function.
    bob_loss: Bob's loss function.  Different units from eve_loss.
    eve_optimizer: A tf op that runs Eve's optimizer.
    bob_optimizer: A tf op that runs Bob's optimizer.
    bob_reconstruction_loss: Bob's message reconstruction loss,
      which is comparable to eve_loss.
    reset_eve_vars: Execute this op to completely reset Eve.
  """

  def get_message_and_key(self):
    """Generate random pseudo-boolean key and message values."""
    # The batch size defaults to FLAGS.batch_size but can be overridden
    # through feed_dict at evaluation time.
    batch_size = tf.compat.v1.placeholder_with_default(
        FLAGS.batch_size, shape=[])

    in_m = batch_of_random_bools(batch_size, TEXT_SIZE)
    in_k = batch_of_random_bools(batch_size, KEY_SIZE)
    return in_m, in_k

  def model(self, collection, message, key=None):
    """The model for Alice, Bob, and Eve.

    If key is None, the first fully connected layer takes only the
    message as input.  Otherwise, it uses both the key and the message.

    Args:
      collection: The graph keys collection to add new vars to.
      message: The input message to process.
      key: The input key (if any) to use.

    Returns:
      The model output: a [batch_size, (TEXT_SIZE + KEY_SIZE) / 2] tensor
      of values in (-1, 1), produced by the final tanh layer.  With the
      sizes above, this is [batch_size, TEXT_SIZE].
    """
    if key is not None:
      combined_message = tf.concat(axis=1, values=[message, key])
    else:
      combined_message = message

    # Ensure that all variables created are in the specified collection.
    with tf.contrib.framework.arg_scope(
        [tf.contrib.layers.fully_connected, tf.contrib.layers.conv2d],
        variables_collections=[collection]):

      fc = tf.contrib.layers.fully_connected(
          combined_message,
          TEXT_SIZE + KEY_SIZE,
          biases_initializer=tf.constant_initializer(0.0),
          activation_fn=None)

      # Perform a sequence of 1D convolutions (by expanding the message out
      # to 2D and then squeezing it back down).
      fc = tf.expand_dims(fc, 2)
      fc = tf.expand_dims(fc, 3)  # conv2d needs a depth dimension.
      # Shape is now [batch, TEXT_SIZE + KEY_SIZE, 1, 1].

      # Stride 2 halves the length axis: [batch, 16, 1, 2].
      conv = tf.contrib.layers.conv2d(
          fc, 2, 2, 2, 'SAME', activation_fn=tf.nn.sigmoid)
      # 1x1 convolution keeps the shape: [batch, 16, 1, 2].
      conv = tf.contrib.layers.conv2d(
          conv, 2, 1, 1, 'SAME', activation_fn=tf.nn.sigmoid)
      # Collapse the depth back to one channel: [batch, 16, 1, 1].
      conv = tf.contrib.layers.conv2d(
          conv, 1, 1, 1, 'SAME', activation_fn=tf.nn.tanh)

      conv = tf.squeeze(conv, 3)
      conv = tf.squeeze(conv, 2)  # Final shape: [batch, TEXT_SIZE].
      return conv

  def __init__(self):
    in_m, in_k = self.get_message_and_key()
    encrypted = self.model('alice', in_m, in_k)
    decrypted = self.model('bob', encrypted, in_k)
    eve_out = self.model('eve', encrypted, None)

    self.reset_eve_vars = tf.group(
        *[w.initializer for w in tf.compat.v1.get_collection('eve')])

    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate)

    # Eve's goal is to decrypt the entire message.  Map the {-1, 1}
    # encodings to {0, 1} and count how many bits differ per example.
    eve_bits_wrong = tf.reduce_sum(
        tf.abs((eve_out + 1.0) / 2.0 - (in_m + 1.0) / 2.0), [1])
    self.eve_loss = tf.reduce_sum(eve_bits_wrong)
    self.eve_optimizer = optimizer.minimize(
        self.eve_loss, var_list=tf.compat.v1.get_collection('eve'))

    # Alice and Bob want to be accurate...
    self.bob_bits_wrong = tf.reduce_sum(
        tf.abs((decrypted + 1.0) / 2.0 - (in_m + 1.0) / 2.0), [1])
    # ... and to not let Eve do better than guessing.
    self.bob_reconstruction_loss = tf.reduce_sum(self.bob_bits_wrong)
    bob_eve_error_deviation = tf.abs(float(TEXT_SIZE) / 2.0 - eve_bits_wrong)
    # 7-9 bits wrong is OK too, so we squish the error function a bit.
    # Without doing this, we often tend to hang out at 0.25 / 7.5 error,
    # and it seems bad to have continued, high communication error.
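    # With TEXT_SIZE = 16, the squared, rescaled deviation below is 0 when
    # Eve gets exactly 8 bits wrong (chance level) and reaches its maximum
    # of 1.0 per example when she gets 0 or 16 bits wrong.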
    bob_eve_loss = tf.reduce_sum(
        tf.square(bob_eve_error_deviation) / (TEXT_SIZE / 2)**2)

    # Rescale the losses to [0, 1] per example and combine.
    self.bob_loss = (self.bob_reconstruction_loss / TEXT_SIZE + bob_eve_loss)

    self.bob_optimizer = optimizer.minimize(
        self.bob_loss,
        var_list=(tf.compat.v1.get_collection('alice') +
                  tf.compat.v1.get_collection('bob')))


def doeval(s, ac, n, itercount):
  """Evaluate the current network on n batches of random examples.

  Args:
    s: The current TensorFlow session.
    ac: An instance of the AdversarialCrypto class.
    n: The number of iterations to run.
    itercount: Iteration count label for logging.

  Returns:
    Bob's and Eve's average reconstruction error, in bits per message
    (0.0 is perfect; TEXT_SIZE / 2 is chance level).
  """
  bob_loss_accum = 0
  eve_loss_accum = 0
  for _ in xrange(n):
    bl, el = s.run([ac.bob_reconstruction_loss, ac.eve_loss])
    bob_loss_accum += bl
    eve_loss_accum += el
  bob_avg_loss = bob_loss_accum / (n * FLAGS.batch_size)
  eve_avg_loss = eve_loss_accum / (n * FLAGS.batch_size)
  print('%10d\t%20.2f\t%20.2f' % (itercount, bob_avg_loss, eve_avg_loss))
  sys.stdout.flush()
  return bob_avg_loss, eve_avg_loss
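

# An example doeval log line (values illustrative; at chance both errors
# hover around TEXT_SIZE / 2 = 8 bits per message):
#          0                7.98                8.03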


def train_until_thresh(s, ac):
  """Train until max loops or until the loss thresholds are both met."""
  for j in xrange(MAX_TRAINING_LOOPS):
    for _ in xrange(ITERS_PER_ACTOR):
      s.run(ac.bob_optimizer)
    for _ in xrange(ITERS_PER_ACTOR * EVE_MULTIPLIER):
      s.run(ac.eve_optimizer)
    if j % PRINT_EVERY == 0:
      bob_avg_loss, eve_avg_loss = doeval(s, ac, EVAL_BATCHES, j)
      if bob_avg_loss < BOB_LOSS_THRESH and eve_avg_loss > EVE_LOSS_THRESH:
        print('Target losses achieved.')
        return True
  return False


def train_and_evaluate():
  """Run the full training and evaluation loop."""
  ac = AdversarialCrypto()
  init = tf.compat.v1.global_variables_initializer()

  with tf.compat.v1.Session() as s:
    s.run(init)
    print('# Batch size: ', FLAGS.batch_size)
    print('# %10s\t%20s\t%20s' % ('Iter', 'Bob_Recon_Error',
                                  'Eve_Recon_Error'))

    if train_until_thresh(s, ac):
      for _ in xrange(EVE_EXTRA_ROUNDS):
        s.run(ac.eve_optimizer)
      print('Loss after Eve extra training:')
      doeval(s, ac, EVAL_BATCHES * 2, 0)
      for _ in xrange(NUMBER_OF_EVE_RESETS):
        print('Resetting Eve')
        s.run(ac.reset_eve_vars)
        eve_counter = 0
        for _ in xrange(RETRAIN_EVE_LOOPS):
          for _ in xrange(RETRAIN_EVE_ITERS):
            eve_counter += 1
            s.run(ac.eve_optimizer)
          doeval(s, ac, EVAL_BATCHES, eve_counter)
        doeval(s, ac, EVAL_BATCHES, eve_counter)


def main(unused_argv):
  # Exit more quietly with Ctrl-C.
  signal.signal(signal.SIGINT, signal.SIG_DFL)
  train_and_evaluate()


if __name__ == '__main__':
  tf.compat.v1.app.run()