magic-eraser / app.py
sab
token client
4b5b8f2
import os
import uuid
import base64
import requests
import numpy as np
from PIL import Image
from io import BytesIO
from pathlib import Path
from dotenv import load_dotenv
import gradio as gr
from gradio_imageslider import ImageSlider # Ensure this library is installed
# Populate os.environ from a .env file (if one exists) so that os.getenv()
# can see settings such as API_KEY and BACKEND_URL.
load_dotenv()

# Module-level output directory, created (non-recursively) at import time.
output_folder = Path("output_images")
output_folder.mkdir(parents=False, exist_ok=True)
def numpy_to_pil(image: np.ndarray) -> Image.Image:
    """Convert a numpy image array to a PIL Image.

    The array is normalised to uint8 and Pillow infers the mode from its
    shape: (H, W) -> "L", (H, W, 3) -> "RGB", (H, W, 4) -> "RGBA".

    Args:
        image: Image data. A trailing singleton channel axis (H, W, 1) is
            dropped. uint8 data is used as-is. Floating-point data whose values
            all lie in [0, 1] is scaled to 0..255. Any other non-uint8 data is
            rounded and clipped to 0..255.

    Returns:
        A new PIL Image that does not share memory with ``image``.
    """
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]

    if image.dtype == np.uint8:
        # Copy so the PIL image never aliases the caller's array (Pillow can
        # wrap the buffer without copying for some modes, e.g. "L"/"RGBA").
        pixels = image.copy()
    else:
        values = image.astype(np.float64)
        # Treat float images in [0, 1] as normalised intensities.
        if np.issubdtype(image.dtype, np.floating) and values.size and values.max() <= 1.0:
            values = values * 255.0
        pixels = np.clip(np.rint(values), 0, 255).astype(np.uint8)

    # No explicit mode: forcing "RGB"/"F" broke grayscale, RGBA and float
    # input, and the `mode` argument is deprecated in recent Pillow.
    return Image.fromarray(pixels)
# (connect, read) timeouts in seconds for the backend call. Without a timeout,
# requests.post can block the Gradio worker indefinitely.
_BACKEND_TIMEOUT = (10, 120)


def _get_required_env(name: str) -> str:
    """Return environment variable *name*; raise ValueError if unset or empty."""
    value = os.getenv(name)
    if not value:
        raise ValueError(f"{name} is not set in the environment variables")
    return value


def _request_processed_image(png_bytes: bytes) -> Image.Image:
    """POST PNG bytes to ``BACKEND_URL/process_image/`` and decode the reply.

    The backend is expected to answer HTTP 200 with JSON containing a
    base64-encoded image under the "processed_image" key.

    Raises:
        ValueError: if API_KEY or BACKEND_URL is not configured.
        gr.Error: if the request fails, the status is not 200, or the
            response body is not the expected payload.
    """
    api_key = _get_required_env("API_KEY")
    backend_url = _get_required_env("BACKEND_URL").rstrip("/")

    try:
        response = requests.post(
            f"{backend_url}/process_image/",
            headers={"access_token": api_key},
            files={"file": ("image.png", png_bytes, "image/png")},
            timeout=_BACKEND_TIMEOUT,
        )
    except requests.RequestException as err:
        raise gr.Error(f"Could not reach the backend: {err}") from err

    if response.status_code != 200:
        raise gr.Error(
            f"Request failed with status code {response.status_code}: {response.text}"
        )

    try:
        processed_image_b64 = response.json()["processed_image"]
        processed_image = Image.open(BytesIO(base64.b64decode(processed_image_b64)))
        # Force decoding now so corrupt data is reported here, not on save.
        processed_image.load()
    except (ValueError, KeyError, TypeError, OSError) as err:
        raise gr.Error(f"Backend returned an unexpected response: {err}") from err
    return processed_image


def _save_output(image: Image.Image) -> Path:
    """Save *image* as a uniquely named PNG in the module-level output folder."""
    # Re-create the folder in case it was removed after import.
    output_folder.mkdir(parents=True, exist_ok=True)
    image_path = output_folder / f"no_bg_image_{uuid.uuid4().hex}.png"
    image.save(image_path)
    return image_path


def process_image(image: np.ndarray):
    """Send the input image to the backend and save the processed result.

    Args:
        image: Input image as a numpy array, or None when nothing was uploaded.

    Returns:
        tuple: ``((processed_image, original_image), saved_png_path)``. The
        first element is the pair of PIL images for the image slider. The
        second is the path of the saved PNG as a string.

    Raises:
        ValueError: if API_KEY or BACKEND_URL is not configured.
        gr.Error: if no image was provided or the backend call fails.
    """
    if image is None:
        # gr.Image yields None when the user submits without an upload.
        raise gr.Error("Please upload an image first.")

    image_pil = numpy_to_pil(image)

    # Send the raw PNG bytes; the former base64 encode/decode round trip
    # produced exactly these bytes again.
    buffered = BytesIO()
    image_pil.save(buffered, format="PNG")

    processed_image = _request_processed_image(buffered.getvalue())
    image_path = _save_output(processed_image)
    return (processed_image, image_pil), str(image_path)
# Example inputs shown in the UI. Each example row supplies one value per
# input component; here that is a single image path.
_EXAMPLE_IMAGE_PATHS = (
    "images/elephant.jpg",
    "images/lion.png",
    "images/tartaruga.png",
)

# Gradio UI: one image input. The outputs are an ImageSlider receiving the
# pair of images returned by process_image, and a downloadable file for the
# saved PNG path.
image = gr.Image(label="Upload a photo")
output_slider = ImageSlider(label="Processed photo", type="pil")
download_file = gr.File(label="output png file")
demo = gr.Interface(
    fn=process_image,
    inputs=image,
    outputs=[output_slider, download_file],
    title="Magic Eraser",
    examples=[[path] for path in _EXAMPLE_IMAGE_PATHS],
)

if __name__ == "__main__":
    # Serve the app only when executed directly, not when imported.
    demo.launch()