RetinaGAN / app.py
farrell236's picture
Update app.py
ed96423
import cv2
import numpy as np
import streamlit as st
import tensorflow as tf
from models.cstylegan import cStyleGAN
from models.gaugan import GauGAN
from utils import fix_pred_label, onehot_to_rgb, rgb_to_onehot, color_dict
from skimage import io
@st.cache_resource
def load_cstylegan():
conditional_style_gan = cStyleGAN(start_res=4, target_res=1024)
conditional_style_gan.grow_model(256)
conditional_style_gan.load_weights('checkpoints/cstylegan/cstylegan_256x256.ckpt').expect_partial()
print('Conditional StyleGAN Model Loaded!')
return conditional_style_gan
@st.cache_resource
def load_gaugan(batch_size):
gaugan = GauGAN(image_size=1024, num_classes=7, batch_size=batch_size, latent_dim=512)
gaugan.load_weights('checkpoints/gaugan/gaugan_1024x1024.ckpt').expect_partial()
print('GauGAN Model Loaded!')
return gaugan
def set_seed():
tf.random.set_seed(seed=st.session_state.seed)
def main():
st.title('RetinaGAN')
st.sidebar.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/sample.jpeg'), cv2.COLOR_BGR2RGB))
st.sidebar.title('Menu')
options = st.sidebar.selectbox('Select Option:', ('About', 'Random', 'Upload your own', 'Retina Template'))
if options == 'About':
st.write('Online Demo for **High-Fidelity Diabetic Retina Fundus Image Synthesis from Freestyle Lesion Maps**')
st.write('''
Paper: https://opg.optica.org/abstract.cfm?uri=boe-14-2-533
Github: http://github.com/farrell236/RetinaGAN
πŸ‘ˆ Select an Option From the drop down menu
---
''')
st.write('''
RetinaGAN a two-step process for generating photo-realistic retinal
Fundus images based on artificially generated or free-hand drawn semantic lesion maps.
''')
st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/RetinaGAN_pipeline.png'), cv2.COLOR_BGR2RGB),
caption='RetinaGAN Pipeline')
st.write('''
StyleGAN is modified to be conditional in to synthesize pathological lesion maps
based on a specified DR grade (i.e., grades 0 to 4). The DR Grades are defined by the
International Clinical Diabetic Retinopathy (ICDR) disease severity scale;
no apparent retinopathy, {mild, moderate, severe} Non-Proliferative Diabetic Retinopathy (NPDR),
and Proliferative Diabetic Retinopathy (PDR). The output of the network is a binary image with
seven channels instead of class colors to avoid ambiguity.
''')
st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/cStyleGAN.png'), cv2.COLOR_BGR2RGB),
caption='Conditional StyleGAN Model')
st.write('''
The generated label maps are then passed through GauGAN, an image-to-image translation network,
to turn them into photo-realistic retina fundus images. The input to the network are one-hot
encoded labels.
''')
st.columns([1, 5, 1])[1].image(cv2.cvtColor(cv2.imread('assets/GauGAN.png'), cv2.COLOR_BGR2RGB),
caption='GauGAN Model')
elif options == 'Random':
st.session_state.seed = st.sidebar.number_input('Sampling Seed:', value=42, on_change=set_seed)
## Load Models
conditional_style_gan = load_cstylegan()
gaugan = load_gaugan(4)
for idx, col in enumerate(st.columns(5)):
z = tf.random.normal((1, conditional_style_gan.z_dim))
w = conditional_style_gan.mapping([z, conditional_style_gan.embedding(idx)])
noise = conditional_style_gan.generate_noise(batch_size=1)
labels = conditional_style_gan.call({"style_code": w, "noise": noise, "alpha": 1.0, "class_label": idx})
labels = tf.keras.backend.softmax(labels)
labels = tf.cast(labels > 0.5, dtype=tf.float32)
labels = tf.image.resize(labels, (1024, 1024), method='nearest')
fixed_labels = fix_pred_label(labels)
fixed_labels = tf.tile(fixed_labels, (4, 1, 1, 1))
latent_vector = tf.random.normal(shape=(4, 512), mean=0.0, stddev=2.0)
fake_image = gaugan.predict([latent_vector, fixed_labels])
with col:
st.text(f'DR Grade {idx}')
st.image(onehot_to_rgb(fixed_labels[0], color_dict), output_format='PNG')
for im in fake_image:
st.image(im)
# Run again?
st.button('Regenerate Images')
elif options == 'Upload your own':
st.session_state.seed = st.sidebar.number_input('Sampling Seed:', value=42, on_change=set_seed)
## Load Models
gaugan = load_gaugan(1)
uploaded_file = st.file_uploader('Choose an image...', type=('png'))
if uploaded_file:
col1, col2 = st.columns(2)
# Read input image with size [H, W, 3] and range (0, 255)
img_array = io.imread(uploaded_file)[..., 0:3]
# Test for valid mask
test_colours = np.unique(img_array.reshape(-1, img_array.shape[2]), axis=0)
if not all([tuple(x) in color_dict.values() for x in test_colours]):
st.info('Mask Contains invalid Class Colours')
return
# Resize image with padding to [1024, 1024, 3]
img_array = tf.image.resize_with_pad(img_array, 1024, 1024, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
# Display input image
with col1:
st.image(img_array.numpy(), caption='Uploaded Image')
img_label = rgb_to_onehot(img_array.numpy(), color_dict)[None, ...]
latent_vector = tf.random.normal(shape=(1, 512), mean=0.0, stddev=2.0)
fake_image = gaugan.predict([latent_vector, img_label])[0]
with col2:
st.image(fake_image, caption='Generated Image')
# Run again?
st.button('Regenerate Image')
elif options == 'Retina Template':
st.header('Template')
st.write('Download the Retina Template image below. '
'Using an image editor of your choice, paint lesions '
'into the Vitreous Body and upload it to the model. '
'NB: Images must be stored as lossless PNGs')
template = np.uint8(cv2.circle(np.zeros((1024, 1024, 3)), [512, 512], 512, (255, 255, 255), -1))
st.columns([1, 5, 1])[1].image(template, use_column_width=True, output_format='PNG')
st.header('Class Colours')
cols = st.columns(7)
for idx, cls in enumerate(color_dict):
with cols[idx]:
st.image(image=np.tile(color_dict[cls], (32, 32, 1)),
caption=cls,
output_format='PNG')
# st.caption(color_dict[cls])
data = {'Class Name': [
'Background',
'Hard Exudate',
'Hemohedge',
'Soft Exudate',
'Micro Aneurysms',
'Optical Disc',
'Vitreous Body'],
'RGB Colour': [
str(color_dict[0]), # BG
str(color_dict[1]), # EX
str(color_dict[2]), # HE
str(color_dict[3]), # SE
str(color_dict[4]), # MA
str(color_dict[5]), # OD
str(color_dict[6])] # VB
}
st.table(data)
if __name__ == '__main__':
# tf.config.set_visible_devices([], 'GPU')
main()