|
import streamlit as st |
|
from PIL import Image |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
# Seed the "clicked" flag only when the session does not have it yet, so an
# existing value is never overwritten here.
if "clicked" not in st.session_state:
    st.session_state["clicked"] = False
|
|
|
def click_button():
    """Button callback: mark in session state that generation was requested."""
    st.session_state["clicked"] = True
|
|
|
# Side length (pixels) of the square images fed to the network.
img_size = 400


@st.cache_resource(show_spinner=False)
def _load_vgg(size):
    """Build the frozen VGG19 feature extractor for ``size`` x ``size`` RGB input.

    Streamlit re-executes the script on every widget interaction; caching the
    model as a resource avoids rebuilding the network and re-reading the
    weight file on each rerun. The spinner is disabled so that nothing is
    rendered before the page configuration call further down the script.

    Args:
        size: Height and width of the expected input images.

    Returns:
        A ``tf.keras`` VGG19 model without its classifier head, loaded from
        the local weight file and marked non-trainable.
    """
    model = tf.keras.applications.VGG19(
        include_top=False,
        input_shape=(size, size, 3),
        weights='vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5',
    )
    model.trainable = False
    return model


vgg = _load_vgg(img_size)
|
|
|
def get_layer_outputs(vgg, layer_names):
    """Return a Keras model mapping ``vgg``'s input to selected layer outputs.

    Args:
        vgg: Model exposing ``get_layer`` and ``input``.
        layer_names: Sequence of entries whose element at index 0 is a layer
            name (e.g. ``(name, weight)`` pairs); only the name is read here.

    Returns:
        A ``tf.keras.Model`` whose outputs follow the order of ``layer_names``.
    """
    selected_outputs = []
    for entry in layer_names:
        selected_outputs.append(vgg.get_layer(entry[0]).output)
    return tf.keras.Model([vgg.input], selected_outputs)
|
|
|
# (layer name, weight) pairs; the weight scales that layer's style cost.
STYLE_LAYERS = [
    ('block1_conv1', 0.2),
    ('block2_conv1', 0.2),
    ('block3_conv1', 0.2),
    ('block4_conv1', 0.2),
    ('block5_conv1', 0.2),
]

# Same (name, weight) shape as above, for the single content layer.
content_layer = [('block5_conv4', 1)]

# Feature extractor: style layer outputs first, content layer output last.
vgg_model_outputs = get_layer_outputs(vgg, [*STYLE_LAYERS, *content_layer])
|
|
|
# Page-wide configuration, centered title and a separator.
st.set_page_config(layout="wide")
st.markdown(
    "<h1 style='text-align: center;'>Neural Style Transfer</h1>",
    unsafe_allow_html=True,
)
st.divider()

# Four columns; only the two middle ones receive widgets.
co1, co2, co3, co4 = st.columns(4)

with co2:
    epochs = st.number_input(
        "Input number of epochs",
        min_value=200,
        max_value=20000,
        step=50,
    )

with co3:
    # Two blank writes are emitted before the button (vertical spacers).
    for _ in range(2):
        st.write(" ")
    st.button(
        'Generate Art',
        on_click=click_button,
        type="primary",
        use_container_width=True,
    )
|
|
|
col1, col2, col3 = st.columns(3)


def _load_image_batch(uploaded_file):
    """Decode an upload into a batch of one ``img_size`` x ``img_size`` RGB image.

    The image is forced to RGB so that grayscale, palette or RGBA uploads
    still produce the three channels the VGG input layer was built with.

    Args:
        uploaded_file: File-like object accepted by ``PIL.Image.open``.

    Returns:
        A NumPy array of shape ``(1, img_size, img_size, 3)``.
    """
    rgb = Image.open(uploaded_file).convert("RGB").resize((img_size, img_size))
    return np.expand_dims(np.array(rgb), axis=0)


with col1:
    content_img = st.file_uploader("Input Content Image")
    if content_img is not None:
        content_image = _load_image_batch(content_img)

        # Content image converted to float32 plus uniform noise in [0, 0.5),
        # clipped back into [0, 1].
        generated_image = tf.Variable(tf.image.convert_image_dtype(content_image, tf.float32))
        noise = tf.random.uniform(tf.shape(generated_image), 0, 0.5)
        generated_image = tf.add(generated_image, noise)
        generated_image = tf.clip_by_value(generated_image, clip_value_min=0.0, clip_value_max=1.0)

        # Activations of the float32 content image; read by train_step.
        preprocessed_content = tf.Variable(tf.image.convert_image_dtype(content_image, tf.float32))
        a_C = vgg_model_outputs(preprocessed_content)

        st.image(content_img, caption="CONTENT IMAGE", use_column_width=True)


with col2:
    style_img = st.file_uploader("Input Style Image")
    if style_img is not None:
        style_image = _load_image_batch(style_img)

        # Activations of the float32 style image; read by train_step.
        preprocessed_style = tf.Variable(tf.image.convert_image_dtype(style_image, tf.float32))
        a_S = vgg_model_outputs(preprocessed_style)

        st.image(style_img, caption="STYLE IMAGE", use_column_width=True)
|
|
|
def compute_content_cost(content_output, generated_output):
    """Compute the content cost from the last activation of each output list.

    Args:
        content_output: Sequence of activations for the content image; only
            the final entry is used.
        generated_output: Sequence of activations for the generated image;
            only the final entry is used. Its static 4-D shape supplies the
            dimensions for reshaping and normalization.

    Returns:
        Scalar tensor: the sum of squared activation differences scaled by
        ``1 / (4 * n_H * n_W * n_C)``.
    """
    content_act = content_output[-1]
    generated_act = generated_output[-1]

    batch, height, width, channels = generated_act.get_shape().as_list()
    unrolled_shape = [batch, -1, channels]

    content_flat = tf.transpose(tf.reshape(content_act, shape=unrolled_shape))
    generated_flat = tf.transpose(tf.reshape(generated_act, shape=unrolled_shape))

    squared_error = tf.reduce_sum(tf.square(tf.subtract(content_flat, generated_flat)))
    scale = 1 / (4 * height * width * channels)
    return scale * squared_error
|
|
|
def gram_matrix(A):
    """Return ``A`` multiplied by its own transpose (``A @ A^T``)."""
    return tf.linalg.matmul(A, A, transpose_b=True)
|
|
|
def compute_layer_style_cost(a_S, a_G):
    """Compute the style cost of one layer from style and generated activations.

    Args:
        a_S: Activations of the style image at this layer.
        a_G: Activations of the generated image at this layer; its static 4-D
            shape supplies ``n_H``, ``n_W`` and ``n_C``.

    Returns:
        Scalar tensor: the sum of squared differences between the two Gram
        matrices, scaled by ``1 / (4 * n_C**2 * (n_H * n_W)**2)``.
    """
    _, height, width, channels = a_G.get_shape().as_list()

    def as_channel_rows(activation):
        # Flatten to (-1, channels), then transpose so each row is a channel.
        return tf.transpose(tf.reshape(activation, shape=[-1, channels]))

    gram_style = gram_matrix(as_channel_rows(a_S))
    gram_generated = gram_matrix(as_channel_rows(a_G))

    normalizer = 1 / (4 * channels ** 2 * (height * width) ** 2)
    gram_difference = tf.subtract(gram_style, gram_generated)
    return normalizer * tf.reduce_sum(tf.square(gram_difference))
|
|
|
def compute_style_cost(style_image_output, generated_image_output, STYLE_LAYERS=STYLE_LAYERS):
    """Compute the weighted sum of per-layer style costs.

    Args:
        style_image_output: Sequence of activations for the style image; the
            last entry is excluded from the style cost.
        generated_image_output: Sequence of activations for the generated
            image, in the same order; the last entry is excluded as well.
        STYLE_LAYERS: Sequence of ``(layer_name, weight)`` pairs aligned with
            the remaining activations; only the weight is used here.

    Returns:
        The accumulated weighted style cost (``0`` if there are no layers).
    """
    J_style = 0

    # Walk the style activations, generated activations and layer weights in
    # lockstep instead of indexing by position.
    layer_triples = zip(style_image_output[:-1], generated_image_output[:-1], STYLE_LAYERS)
    for style_act, generated_act, (_layer_name, weight) in layer_triples:
        J_style += weight * compute_layer_style_cost(style_act, generated_act)

    return J_style
|
|
|
@tf.function()
def total_cost(J_content, J_style, alpha=10, beta=40):
    """Combine the content and style costs into one objective.

    Args:
        J_content: Content cost.
        J_style: Style cost.
        alpha: Multiplier applied to the content cost.
        beta: Multiplier applied to the style cost.

    Returns:
        ``alpha * J_content + beta * J_style``.
    """
    weighted_content = alpha * J_content
    weighted_style = beta * J_style
    return weighted_content + weighted_style
|
|
|
def clip_0_1(image):
    """Return ``image`` with every value clamped to the range [0.0, 1.0]."""
    lower, upper = 0.0, 1.0
    return tf.clip_by_value(image, clip_value_min=lower, clip_value_max=upper)
|
|
|
def tensor_to_image(tensor):
    """Convert a tensor to a PIL image.

    Args:
        tensor: Array-like image data, optionally with a leading batch axis
            of size 1. Values are multiplied by 255 and cast to ``uint8``.

    Returns:
        A ``PIL.Image.Image`` built from the scaled pixel data.

    Raises:
        ValueError: If the input has more than three dimensions and its
            leading axis does not have size 1.
    """
    pixels = np.array(tensor * 255, dtype=np.uint8)
    if np.ndim(pixels) > 3:
        # Raise explicitly rather than assert: assertions are stripped under
        # ``python -O``, which would silently discard extra batch entries.
        if pixels.shape[0] != 1:
            raise ValueError(
                f"expected a batch of exactly one image, got {pixels.shape[0]}"
            )
        pixels = pixels[0]
    return Image.fromarray(pixels)
|
|
|
# Adam optimizer that train_step uses to update the generated image.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)


@tf.function()
def train_step(generated_image):
    """Run one optimization step on ``generated_image`` in place.

    Reads the module-level ``a_S`` and ``a_C`` activations, so both must be
    defined before this function is first called.

    Args:
        generated_image: ``tf.Variable`` holding the image being optimized;
            it is updated by the optimizer and then clipped to [0, 1].

    Returns:
        The total cost evaluated before the update was applied.
    """
    with tf.GradientTape() as tape:
        generated_activations = vgg_model_outputs(generated_image)
        style_cost = compute_style_cost(a_S, generated_activations)
        content_cost = compute_content_cost(a_C, generated_activations)
        cost = total_cost(content_cost, style_cost)

    gradient = tape.gradient(cost, generated_image)
    optimizer.apply_gradients([(gradient, generated_image)])

    # Keep pixel values in the valid range after the gradient update.
    generated_image.assign(clip_0_1(generated_image))
    return cost
|
|
|
with col3:
    st.write("Generated Image")

    if st.session_state.clicked:
        if content_img is None or style_img is None:
            # content_image, a_C and a_S only exist once both files have been
            # uploaded; without this guard the training code below would fail
            # with a NameError.
            st.warning("Please upload both a content image and a style image before generating.")
        else:
            # Optimization starts from the content image converted to float32.
            generated_image = tf.Variable(tf.image.convert_image_dtype(content_image, tf.float32))

            # Three blank writes precede the output placeholders (spacers).
            for _ in range(3):
                st.write(" ")

            # Placeholders are overwritten in place on every epoch.
            placeholder_1 = st.empty()
            placeholder_2 = st.empty()

            for epoch in range(epochs):
                train_step(generated_image)
                placeholder_1.image(tensor_to_image(generated_image))
                placeholder_2.write(f"Epoch {epoch}")