# Spaces: Running
import io
import os
import warnings
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
import streamlit as st
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
from tensorflow.keras.models import load_model
# Page-level settings: browser-tab title and icon, centered content column,
# and the sidebar expanded on first load.
# NOTE(review): the page_icon emoji looks mis-encoded in this copy of the file
# (UTF-8 bytes rendered as Greek letters) — confirm it is a real emoji.
_PAGE_CONFIG = dict(
    page_title="Sketch to Image using GAN",
    layout="centered",
    page_icon="ποΈ",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# Global style overrides, rendered as raw HTML: dark page background with
# white monospace text, bold tomato-colored h1/h2 headings, and tomato buttons
# with a rounded white border.
_CUSTOM_CSS = """
<style>
body {
background-color: #2a2a2a; /* Cool dark background */
color:#ffffff; /* White text */
font-family: 'Courier New', monospace; /* Cool font */
}
h1, h2 {
color: #ff6347; /* Tomato color for titles */
font-weight: bold; /* Bold text */
}
.stButton>button {
color: #ffffff;
background-color: #ff6347; /* Tomato color for buttons */
border-radius: 10px;
border: 2px solid #ffffff;
font-weight: bold; /* Bold text */
}
</style>
"""
# unsafe_allow_html is required, otherwise the <style> tag is escaped as text.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Page header: centered title and subtitle (inline HTML), logo, and a short
# description of the project.
st.markdown("<h1 style='text-align: center; color: #ff6347;'>Sketch to Image using GAN ποΈ</h1>", unsafe_allow_html=True)
st.markdown("<h2 style='text-align: center; color: #ff6347;'>Empowering Multiple Fields with GANs π</h2>", unsafe_allow_html=True)

# Resolve the logo next to this script rather than the current working
# directory, so the app still finds it when launched from another folder.
_LOGO_PATH = Path(__file__).resolve().with_name("home1.jpeg")
# Open in a `with` block so the file handle is closed deterministically;
# st.image runs inside the block while the image is still open.
with Image.open(_LOGO_PATH) as logo_image:
    st.image(logo_image, width=300)

st.write("The application of Generative Adversarial Networks (GANs) in the Sketch to Image project extends beyond creative endeavors, finding significant utility in various fields. The ability to transform sketches into vibrant and detailed images has far-reaching implications, especially in sectors such as law enforcement, forensic science, and more.")
# ---------------------------------------------------------------------------
# Upload section: show the user's sketch and, on request, send it to the
# image-generation API and display the returned image.
# ---------------------------------------------------------------------------

# Image-generation endpoint. Defaults to the local FastAPI server; set the
# GENERATE_IMAGE_URL environment variable to point somewhere else.
_GENERATE_IMAGE_URL = os.environ.get(
    "GENERATE_IMAGE_URL", "http://127.0.0.1:8000/generate-image/"
)
# (connect, read) timeouts in seconds. Without a timeout, requests can block
# indefinitely and leave the spinner running forever.
_GENERATE_TIMEOUT = (5, 120)

st.markdown("<h2 style='text-align: center; color: #ff6347;'>Upload Your Sketch π€</h2>", unsafe_allow_html=True)
uploaded_file = st.file_uploader("Choose an image... π€", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Center the sketch by placing it in the wide middle column of a 1:2:1 split.
    _, center, _ = st.columns([1, 2, 1])
    with center:
        st.image(uploaded_file, caption="Uploaded Image πΌοΈ", width=300)

    if st.button('Generate π'):
        with st.spinner('Wait for it... Generating your image π¨'):
            # Multipart field "file" (what the API reads), with the original
            # filename and browser-reported MIME type.
            files = {
                "file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)
            }
            try:
                response = requests.post(
                    _GENERATE_IMAGE_URL, files=files, timeout=_GENERATE_TIMEOUT
                )
            except requests.ConnectionError:
                # Also covers ConnectTimeout (a ConnectionError subclass).
                st.error("Unable to connect to the FastAPI server. Please make sure it is running.")
            except requests.Timeout:
                st.error("The FastAPI server did not respond in time. Please try again.")
            except requests.RequestException as exc:
                st.error(f"Request to the FastAPI server failed: {exc}")
            else:
                if response.status_code != 200:
                    st.error(f"Error in image generation π’ (HTTP {response.status_code})")
                else:
                    try:
                        generated_image = Image.open(BytesIO(response.content))
                        generated_image.load()
                    except OSError:
                        # Body was not a decodable image
                        # (PIL.UnidentifiedImageError subclasses OSError).
                        st.error("Error in image generation π’")
                    else:
                        _, center, _ = st.columns([1, 2, 1])
                        with center:
                            st.image(generated_image, caption="Generated Image β¨", width=300)
# ---------------------------------------------------------------------------
# FastAPI section: generator model and API application.
# ---------------------------------------------------------------------------

# NOTE(review): this silences *all* warnings process-wide (presumably to hide
# TensorFlow/Keras noise) — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Generator weights, resolved next to this script instead of the current
# working directory. Update this with your generator model's path.
_GENERATOR_MODEL_PATH = Path(__file__).resolve().with_name('last_13k_data_generator_model.h5')


@st.cache_resource
def _load_generator_model():
    """Load the Keras generator once and reuse it across script reruns.

    Streamlit re-executes this script on user interaction; without the cache
    every rerun would reload the model file from disk. ``compile=False``
    because the model is only used for inference (``predict``), so no
    optimizer/loss state needs restoring.
    """
    return load_model(str(_GENERATOR_MODEL_PATH), compile=False)


generator_model = _load_generator_model()
app = FastAPI()
@app.post("/generate-image/")
async def generate_image(file: UploadFile = File(...)):
    """Turn an uploaded sketch into a generated image.

    Served at ``POST /generate-image/`` (the default URL the Streamlit client in
    this file posts to), reading the multipart field ``file``.

    Pipeline: decode -> RGB -> resize to 256x256 -> scale [0, 255] to [-1, 1]
    -> ``generator_model.predict`` on a batch of one -> map [-1, 1] back to
    [0, 255] -> JPEG (quality 70).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        sketch = Image.open(io.BytesIO(contents)).convert('RGB')
    except OSError:
        # Undecodable/truncated upload (PIL.UnidentifiedImageError is an
        # OSError): this is bad client input, not a server error.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from None
    sketch = sketch.resize((256, 256))

    # [0, 255] -> [-1, 1]; inverted on the model output below.
    model_input = (np.asarray(sketch, dtype=np.float32) - 127.5) / 127.5
    model_input = np.expand_dims(model_input, axis=0)  # add batch dimension

    generated = generator_model.predict(model_input)
    # [-1, 1] -> [0, 1], clipped so out-of-range values cannot wrap around
    # when cast to uint8.
    generated = np.clip((generated + 1) / 2.0, 0.0, 1.0)
    generated = (np.squeeze(generated) * 255).astype(np.uint8)

    img_io = io.BytesIO()
    Image.fromarray(generated).save(img_io, 'JPEG', quality=70)
    img_io.seek(0)  # rewind so the response streams from the start
    return StreamingResponse(img_io, media_type='image/jpeg')
if __name__ == '__main__':
    # `streamlit run` also executes this script as __main__, so without this
    # check every Streamlit run would call uvicorn.run inside the script run
    # (blocking it, or failing to bind port 8000 if the API already runs).
    # Start the API server only for a plain `python <this file>` launch.
    from streamlit import runtime as _st_runtime

    if not _st_runtime.exists():
        import uvicorn

        uvicorn.run(app, host='127.0.0.1', port=8000)