File size: 2,901 Bytes
43acc4d
b9b231f
429d253
 
 
 
b9b231f
 
 
429d253
b9b231f
07ce7ff
 
429d253
 
3fdd1de
429d253
 
3fdd1de
429d253
 
 
 
43acc4d
429d253
 
 
 
43acc4d
429d253
 
 
b9b231f
429d253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43acc4d
 
b9b231f
 
 
 
 
 
 
 
 
 
 
43acc4d
b9b231f
429d253
b9b231f
 
ff9b5fb
07ce7ff
 
944ffcd
43acc4d
07ce7ff
 
 
 
 
 
 
 
 
 
0170510
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import uuid
import base64
import requests
from PIL import Image
from io import BytesIO
from pathlib import Path
import gradio as gr
from gradio_imageslider import ImageSlider  # Ensure this library is installed
from dotenv import load_dotenv

import config

# Read variables from a local .env file into the process environment, so the
# os.environ lookups below (and BACKEND_URL inside process_image) can see them.
load_dotenv()

# Shared secret sent to the backend in the "access_token" request header.
# Evaluates to None when API_KEY is not defined in the environment.
api_key = os.environ.get('API_KEY')

# (connect, read) timeouts in seconds for the backend call. Without a timeout,
# requests waits indefinitely and a stalled backend would hang the UI worker.
REQUEST_TIMEOUT = (10, 300)


def _encode_png_base64(img):
    """Serialize a PIL image as PNG and return it as a base64 ASCII string."""
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()


# Call the FastAPI prediction endpoint
def process_image(input_image_editor):
    """Send the edited image and its mask to the backend and return the result.

    Args:
        input_image_editor: Value produced by the ``gr.ImageEditor`` input
            (configured with ``type='pil'``): a dict whose ``'background'``
            entry is the uploaded image and whose ``'layers'`` list holds the
            drawn layers; the first layer is sent as the mask.

    Returns:
        ``((processed_image, input_image), image_path)``: the pair feeds the
        ``ImageSlider`` output and ``image_path`` is the string path of the
        PNG written under ``output/``.

    Raises:
        ValueError: if no image was supplied or there is no mask layer.
        RuntimeError: if ``BACKEND_URL`` is not set, or the backend replies
            with a non-200 status code.
        requests.RequestException: on connection errors or timeouts.
    """
    # Gradio passes None when the editor is empty; fail with a readable message
    # instead of a TypeError on subscripting None.
    if not input_image_editor or input_image_editor.get('background') is None:
        raise ValueError("No input image provided.")
    layers = input_image_editor.get('layers') or []
    if not layers:
        raise ValueError("No mask layer found; draw on the image before submitting.")

    input_image = input_image_editor['background']
    mask_image = layers[0]

    backend_url = os.getenv('BACKEND_URL')
    if not backend_url:
        raise RuntimeError("BACKEND_URL environment variable is not set.")

    # Build the POST payload: both images travel as base64-encoded PNGs.
    payload = {
        "input_image_editor": {
            "background": _encode_png_base64(input_image),
            "layers": [_encode_png_base64(mask_image)],
        }
    }

    # POST to the FastAPI backend; strip a trailing slash to avoid "//predict/".
    response = requests.post(
        backend_url.rstrip('/') + "/predict/",
        headers={"access_token": api_key},
        json=payload,
        timeout=REQUEST_TIMEOUT,
    )

    if response.status_code != 200:
        raise RuntimeError(f"Request failed with status code {response.status_code}")

    result = response.json()
    processed_image = Image.open(BytesIO(base64.b64decode(result['processed_image'])))

    # Persist the result under a unique name so concurrent users never clash.
    output_folder = Path("output")
    output_folder.mkdir(parents=True, exist_ok=True)
    image_path = output_folder / f"no_bg_image_{uuid.uuid4().hex}.png"
    processed_image.save(image_path)

    return (processed_image, input_image), str(image_path)



# Define inputs and outputs for the Gradio interface.
# Input: an image editor with a single fixed white brush; process_image sends
# the first drawn layer to the backend as the mask.
image = gr.ImageEditor(
    label='Image',
    type='pil',
    sources=["upload", "webcam"],
    image_mode='RGB',
    layers=False,
    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
)
# Output: comparison slider fed with the (processed, original) pair.
output_slider = ImageSlider(label="Processed photo", type="pil")

# Center the title and description using custom CSS.
# NOTE(review): confirm these class names exist in the DOM rendered by the
# installed Gradio version; if they do not, these rules have no effect.
CUSTOM_CSS = """
    .interface-title {
        text-align: center;
    }
    .interface-description {
        text-align: center;
    }
"""

demo = gr.Interface(
    fn=process_image,
    inputs=image,
    outputs=[output_slider, gr.File(label="output png file")],
    title=config.TITLE,
    description=config.DESCRIPTION,
    article=config.BUY_ME_A_COFFE,
    # Pass the CSS at construction time (documented parameter) rather than
    # assigning demo.css afterwards, which may be too late if the app config
    # has already been built when the Interface is created.
    css=CUSTOM_CSS,
)

demo.launch(debug=False, show_error=True, share=True)