# Lennard Schober
# Add share=True to launch
# 754f675
import os
import random
import tensorflow as tf
from keras import models
import numpy as np
import gradio as gr
import cv2
# Load the trained generator once, at import time. If loading fails the app
# still starts, but `generator` is set to None so the failure is explicit
# instead of surfacing later as a NameError inside the inference code.
try:
    generator = models.load_model("generator.keras")
    print("Model loaded successfully!")
except Exception as e:  # top-level boundary: report and keep the UI running
    generator = None
    print("Error loading model:", e)
def preprocess_image(img):
    """Turn an input image into the generator's (1, 256, 256, 1) float32 tensor.

    The image is resized to 256x256 and its 0-255 values are mapped linearly
    to [-1, 1]. colorize_image() uses the result as the Lab lightness (L)
    channel.

    Accepted inputs:
        * a 2-D grayscale array (used directly);
        * an (H, W, 1) array (the singleton channel is dropped);
        * an (H, W, 3) RGB or (H, W, 4) RGBA array, reduced to its OpenCV
          8-bit Lab L channel (assumes uint8 input).

    Raises:
        ValueError: if ``img`` is None or has an unsupported shape.
    """
    if img is None:
        raise ValueError("preprocess_image() needs an image, got None")
    img = np.asarray(img)
    if img.ndim == 3:
        channels = img.shape[-1]
        if channels == 1:
            img = img[..., 0]
        elif channels in (3, 4):
            # Keep only Lab lightness: it is the same representation that
            # postprocess_image() feeds back through LAB->RGB, so the output
            # keeps the input's luminance. Alpha (if any) is discarded.
            rgb = np.ascontiguousarray(img[..., :3])
            img = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)[..., 0]
        else:
            raise ValueError(
                f"preprocess_image() expects 1, 3 or 4 channels, got {channels}"
            )
    elif img.ndim != 2:
        raise ValueError(
            f"preprocess_image() expects a 2-D or 3-D image, got shape {img.shape}"
        )

    img = cv2.resize(img, (256, 256))
    # Map 0..255 to -1..1.
    img = img.astype("float32") / 127.5 - 1.0

    tensor = tf.convert_to_tensor(img, dtype=tf.float32)
    tensor = tf.expand_dims(tensor, axis=-1)  # channel dimension
    return tf.expand_dims(tensor, axis=0)  # batch dimension
def postprocess_image(img):
    """Convert a Lab image scaled to [-1, 1] back into an RGB uint8 image.

    Each channel is mapped linearly from [-1, 1] to [0, 255] (the inverse of
    the scaling in preprocess_image), rounded, clipped, and converted from
    OpenCV's 8-bit Lab encoding to RGB.

    Args:
        img: (H, W, 3) Lab values in [-1, 1]. Anything ``np.asarray`` accepts
            works, e.g. a TensorFlow eager tensor or a NumPy array.

    Returns:
        (H, W, 3) uint8 RGB array.
    """
    lab = (np.asarray(img, dtype=np.float32) + 1.0) * 127.5
    # Round rather than truncate: float error in the [-1, 1] round trip can
    # land just below an integer (e.g. 2.9999), and truncation would drop a
    # level. Clip so out-of-range model outputs saturate instead of
    # overflowing the uint8 cast.
    lab = np.clip(np.rint(lab), 0, 255).astype(np.uint8)
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
def adjust_brightness(img, brightness=0.0):
    """Shift every pixel value of ``img`` by a brightness-slider offset.

    ``brightness`` in [-1, 1] becomes an integer offset of
    ``int(brightness * 127 / 4)``, i.e. at most +/-31 levels. Results are
    clipped to [0, 255]. The previous ``cv2.convertScaleAbs`` version took
    the absolute value, so a negative offset made dark pixels lighter
    (black became 31).

    Returns a uint8 array, or None when ``img`` is None (nothing colorized
    yet).
    """
    if img is None:
        return None
    offset = int(brightness * 127.0 / 4.0)
    shifted = np.asarray(img, dtype=np.float32) + offset
    return np.clip(np.rint(shifted), 0, 255).astype(np.uint8)
def adjust_contrast(img, contrast=0.0):
    """Multiply every pixel value of ``img`` by a contrast-slider gain.

    ``contrast`` in [-1, 1] maps to a gain of ``contrast * 0.75 + 1``
    (0.25 .. 1.75; 0 leaves the image unchanged). Results are rounded and
    clipped to [0, 255].

    NOTE(review): the gain is applied about 0, not about mid-grey, so it also
    changes overall brightness. Confirm this is the intended "contrast".

    Returns a uint8 array, or None when ``img`` is None (nothing colorized
    yet).
    """
    if img is None:
        return None
    gain = contrast * 0.75 + 1.0
    scaled = np.asarray(img, dtype=np.float32) * np.float32(gain)
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
def adjust_hue(img, hue_shift=0.0):
    """Rotate the hue of an RGB uint8 image.

    ``hue_shift`` in [-1, 1] moves OpenCV's hue channel (0..179, i.e. 2-degree
    steps) by ``hue_shift * 90`` units, which is up to +/-180 degrees, wrapping
    around. Images in this app are RGB (postprocess_image returns LAB->RGB
    output), so RGB<->HSV conversions are used. Treating them as BGR reversed
    the rotation direction.

    Returns a uint8 array, or None when ``img`` is None (nothing colorized
    yet).
    """
    if img is None:
        return None
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    # Hue wraps at 180 in OpenCV's 8-bit HSV; the cast truncates fractions.
    rotated = (hsv[:, :, 0] + hue_shift * 90) % 180
    hsv[:, :, 0] = rotated.astype(np.uint8)
    out = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return np.uint8(np.clip(out, 0, 255))
def adjust_saturation(img, saturation_factor=0.0):
    """Scale the saturation of an RGB uint8 image.

    The HSV saturation channel is multiplied by ``saturation_factor + 1``:
    -1 removes all colour, 0 keeps the saturation unchanged and 1 doubles it
    (clipped to 255).

    Returns a uint8 array, or None when ``img`` is None (nothing colorized
    yet).
    """
    if img is None:
        return None
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    scaled = hsv[:, :, 1] * (saturation_factor + 1.0)
    # Clip before the uint8 cast; the cast truncates the fractional part.
    hsv[:, :, 1] = np.clip(scaled, 0, 255).astype(np.uint8)
    out = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return np.uint8(np.clip(out, 0, 255))
def colorize_image(input_image):
    """Colorize a grayscale image with the loaded generator.

    The preprocessed lightness channel is stacked with the generator's output
    channels (presumably Lab a/b, judging by how the result is converted) and
    turned into RGB by postprocess_image().

    Returns:
        256x256 uint8 RGB array.

    Raises:
        gr.Error: if the model failed to load (``generator`` is None).
    """
    if generator is None:
        raise gr.Error(
            "The colorization model is not available; check the server logs."
        )
    model_input = preprocess_image(input_image)
    # For a single sample, calling the model directly avoids predict()'s
    # per-call setup overhead and progress-bar output.
    output_ab = generator(model_input, training=False)
    # Cast in case the model emits a different float dtype; tf.concat needs
    # matching dtypes.
    lab = tf.concat(
        [model_input[0], tf.cast(output_ab[0], tf.float32)], axis=-1
    )
    return postprocess_image(lab)
def colorize_and_store(img, bright_slider, cont_slider, sat_slider, hue_slider):
    """Colorize ``img`` once and apply the current slider settings.

    Returns:
        (colorized_image, output_image): the unadjusted model output, which
        the UI keeps in a gr.State so sliders can re-adjust it without
        re-running the model, and the adjusted image to display. Adjustments
        are applied in the order brightness, contrast, saturation, hue.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error(
            "Please upload an image (or pick an example) before clicking Colorize."
        )
    colorized_image = colorize_image(img)
    output_image = adjust_brightness(colorized_image, bright_slider)
    output_image = adjust_contrast(output_image, cont_slider)
    output_image = adjust_saturation(output_image, sat_slider)
    output_image = adjust_hue(output_image, hue_slider)
    return colorized_image, output_image
def make_grayscale_256(img):
    """Return ``img`` resized to 256x256 pixels.

    Note: despite its name, no grayscale conversion is applied. The image is
    only resized with ``cv2.resize``; no colour-space conversion is done.
    """
    resized = cv2.resize(img, (256, 256))
    return resized
# Stylesheet passed to gr.Blocks below: centres the <h1> title and <p>
# description, and renders the image inside the element with id
# "input-image" (the input gr.Image) with a grayscale CSS filter. The filter
# only changes how it is displayed, not the pixel data the app receives.
css = """
h1 {
text-align: center;
display:block;
font-size: 3rem;
margin: 0;
padding: 0.5rem;
line-height: 1;
overflow: hidden;
}
p {
text-align: center;
display:block;
font-size:1.5rem;
margin: 0;
padding: 0.5rem;
line-height: 1;
overflow: hidden;
}
#input-image img {
filter: grayscale(1);
}
"""
# Example images offered under the input: every PNG/JPEG/WebP file in the
# local "examples" folder. Sorted because os.listdir() returns entries in
# arbitrary order, which would make the examples order vary by platform.
image_files = sorted(
    os.path.join("examples", name)
    for name in os.listdir("examples")
    if name.lower().endswith((".png", ".jpg", ".jpeg", ".webp"))
)
# Gradio interface


def _apply_adjustments(img, brightness, contrast, saturation, hue):
    """Re-apply every slider setting to the stored colorized image.

    Uses the same order as colorize_and_store() (brightness, contrast,
    saturation, hue), so moving one slider keeps the effect of the others.
    Returns None while nothing has been colorized yet (the gr.State still
    holds its initial None).
    """
    if img is None:
        return None
    adjusted = adjust_brightness(img, brightness)
    adjusted = adjust_contrast(adjusted, contrast)
    adjusted = adjust_saturation(adjusted, saturation)
    return adjust_hue(adjusted, hue)


with gr.Blocks(css=css, title="Portrait Colorizer") as demo:
    # Title and short usage description (styled by the `css` string).
    gr.HTML("<h1>Portrait Colorizer</h1>")
    gr.HTML("<p>Upload a grayscale image to colorize it and fine-tune the output using the sliders below.</p>")

    with gr.Row():
        input_image = gr.Image(
            type="numpy",
            label="Grayscale Image",
            image_mode="L",
            height=256,
            width=256,
            elem_id="input-image",
        )
        examples_gallery = gr.Examples(
            examples=image_files, inputs=[input_image], label="Example Images"
        )
        output_image = gr.Image(
            type="numpy",
            label="Colorized Image",
            image_mode="RGB",
            height=256,
            width=256,
        )

    process_button = gr.Button("Colorize")
    bright_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Brightness")
    cont_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Contrast")
    sat_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Saturation")
    hue_slider = gr.Slider(-1.0, 1.0, value=0.0, label="Hue")
    # Order must match the parameter order of colorize_and_store() and
    # _apply_adjustments(): brightness, contrast, saturation, hue.
    adjustment_sliders = [bright_slider, cont_slider, sat_slider, hue_slider]

    # Unadjusted model output, kept so slider changes never re-run the model.
    colorized_image = gr.State()

    # The button runs the model once and applies the current slider values.
    process_button.click(
        fn=colorize_and_store,
        inputs=[input_image, *adjustment_sliders],
        outputs=[colorized_image, output_image],
    )

    # Every slider re-applies all four adjustments to the stored image, so
    # changing one slider never discards the others' settings.
    for slider in adjustment_sliders:
        slider.change(
            fn=_apply_adjustments,
            inputs=[colorized_image, *adjustment_sliders],
            outputs=output_image,
        )

# Launch the app
demo.launch(share=True, ssr_mode=False)