Spaces:
Running
Running
import streamlit as st | |
from PIL import Image | |
import torch | |
from RealESRGAN import RealESRGAN | |
from io import BytesIO | |
# Function to load the model based on scale and anime toggle
def load_model(scale, anime=False):
    """Create a RealESRGAN model and load the weights for the requested mode.

    Args:
        scale: Upscaling factor. Weights exist for 2, 4 and 8 when ``anime``
            is False, and only for 4 when ``anime`` is True.
        anime: When True, select the anime weights file.

    Returns:
        The model with its weights loaded, or ``None`` when the
        (scale, anime) combination has no weights file or loading fails.
        Failures are reported through ``st.error`` instead of being raised.
    """
    weight_paths = {
        (2, False): 'model/RealESRGAN_x2.pth',
        (4, False): 'model/RealESRGAN_x4plus.pth',
        (8, False): 'model/RealESRGAN_x8.pth',
        (4, True): 'model/RealESRGAN_x4plus_anime_6B.pth',
    }
    try:
        # Resolve the weights file before building the model: an unsupported
        # combination then fails fast with a readable message instead of a
        # bare KeyError whose text is only the tuple, e.g. "(8, True)".
        model_path = weight_paths.get((scale, anime))
        if model_path is None:
            raise ValueError(
                f"unsupported combination: scale={scale}, anime={anime}"
            )
        # Use the GPU when one is available, otherwise fall back to the CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = RealESRGAN(device, scale=scale, anime=anime)
        model.load_weights(model_path)
        return model
    except Exception as e:
        st.error(f"Failed to load the model: {e}")
        return None
def enhance_image(image, scale, anime):
    """Upscale ``image`` and return it together with PNG-encoded bytes.

    Args:
        image: Input image; it is converted to RGB when its ``mode`` is
            anything other than ``'RGB'``.
        scale: Upscaling factor forwarded to ``load_model``.
        anime: Anime flag forwarded to ``load_model``.

    Returns:
        A ``(enhanced_image, buffer)`` pair, where ``buffer`` is a
        ``BytesIO`` holding the PNG encoding and positioned at its start.
        Returns ``(None, None)`` when the model could not be loaded or any
        step raised; in the latter case the error is shown via ``st.error``.
    """
    try:
        upscaler = load_model(scale, anime=anime)
        if upscaler is None:
            return None, None

        # The model is always fed an RGB image, whatever the input mode was.
        rgb_input = image if image.mode == 'RGB' else image.convert('RGB')
        upscaled = upscaler.predict(rgb_input)

        # Encode the result as PNG in memory and rewind so that readers
        # start from the first byte.
        png_stream = BytesIO()
        upscaled.save(png_stream, format="PNG")
        png_stream.seek(0)
    except Exception as e:
        st.error(f"An error occurred during image enhancement: {e}")
        return None, None
    else:
        return upscaled, png_stream
def main():
    """Render the Streamlit page for uploading and restoring an image.

    The page lets the user upload a PNG/JPEG, pick the anime mode or an
    upscaling factor, and press a button to run ``enhance_image``. On
    success the original and enhanced images are shown side by side with a
    download button for the PNG result. Any exception raised while handling
    the upload is reported through ``st.error``.
    """
    st.title("Generative AI Image Restoration")

    uploaded_image = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
    if uploaded_image is None:
        return

    try:
        image = Image.open(uploaded_image)

        anime = st.checkbox("Anime Image", value=False)
        # Anime mode is pinned to 4x; the factor selector is only rendered
        # for regular images.
        scale = "4x" if anime else st.radio(
            "Upscaling Factor", ["2x", "4x", "8x"], index=0
        )
        scale_value = int(scale.replace('x', ''))

        if not st.button("Restore Image"):
            return

        enhanced_image, buffer = enhance_image(image, scale_value, anime)
        if not enhanced_image:
            return

        # Original on the left, enhanced result on the right.
        panels = zip(
            st.columns(2),
            (image, enhanced_image),
            ("Original Image", "Enhanced Image"),
        )
        for column, picture, caption in panels:
            with column:
                st.image(picture, caption=caption, use_column_width=True)

        st.download_button(
            label="Download Enhanced Image",
            data=buffer,
            file_name="enhanced_image.png",
            mime="image/png"
        )
    except Exception as e:
        st.error(f"An error occurred while processing the image: {e}")
# Build the page only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()