# HelloAI1's picture
# Update app.py
# 2eaf1c2 verified
import gradio as gr
import requests
import time
import json
import base64
import os
from PIL import Image
from io import BytesIO
class Prodia:
    """Minimal client for the Prodia SDXL REST API.

    Every request carries the API key in the ``X-Prodia-Key`` header. Any
    non-200 response raises a plain ``Exception`` whose message starts with
    ``"Bad Prodia Response: <status>"``.
    """

    def __init__(self, api_key, base=None, request_timeout=30.0):
        """Create a client.

        Args:
            api_key: Value sent in the ``X-Prodia-Key`` header.
            base: API root URL; defaults to ``https://api.prodia.com/v1``.
            request_timeout: Seconds passed to ``requests`` as ``timeout`` for
                every HTTP call, so a stalled connection cannot hang the app
                forever. ``None`` disables the limit.
        """
        self.base = base or "https://api.prodia.com/v1"
        self.headers = {
            "X-Prodia-Key": api_key
        }
        self.request_timeout = request_timeout

    def generate(self, params):
        """Submit an SDXL generation job and return the decoded JSON response."""
        response = self._post(f"{self.base}/sdxl/generate", params)
        return response.json()

    def get_job(self, job_id):
        """Fetch the current state of job ``job_id`` as decoded JSON."""
        response = self._get(f"{self.base}/job/{job_id}")
        return response.json()

    def wait(self, job, timeout=None, poll_interval=0.25):
        """Poll until ``job`` reaches status ``'succeeded'`` or ``'failed'``.

        Args:
            job: Job dict as returned by :meth:`generate`; it must contain a
                ``'status'`` key and the job id under ``'job'``.
            timeout: Maximum seconds to keep polling. ``None`` (the default)
                polls indefinitely, matching the original behaviour.
            poll_interval: Seconds to sleep between polls.

        Returns:
            The most recently fetched job dict, or ``job`` itself if it was
            already finished.

        Raises:
            TimeoutError: if ``timeout`` elapses before the job finishes.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        job_result = job
        while job_result['status'] not in ['succeeded', 'failed']:
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Prodia job {job['job']} did not finish within {timeout}s"
                )
            time.sleep(poll_interval)
            job_result = self.get_job(job['job'])
        return job_result

    def list_models(self):
        """Return the SDXL model names reported by the API (decoded JSON)."""
        response = self._get(f"{self.base}/sdxl/models")
        return response.json()

    def list_samplers(self):
        """Return the SDXL sampler names reported by the API (decoded JSON)."""
        response = self._get(f"{self.base}/sdxl/samplers")
        return response.json()

    @staticmethod
    def _check(response):
        """Raise ``Exception`` for any non-200 response, including a body excerpt."""
        if response.status_code != 200:
            # Include a short slice of the body: Prodia's error text is far
            # more useful for debugging than the bare status code.
            raise Exception(
                f"Bad Prodia Response: {response.status_code} {response.text[:200]}"
            )
        return response

    def _post(self, url, params):
        """POST ``params`` as a JSON body to ``url``; raise on non-200."""
        headers = {
            **self.headers,
            "Content-Type": "application/json"
        }
        response = requests.post(
            url,
            headers=headers,
            data=json.dumps(params),
            timeout=self.request_timeout,
        )
        return self._check(response)

    def _get(self, url):
        """GET ``url`` with the auth header; raise on non-200."""
        response = requests.get(url, headers=self.headers, timeout=self.request_timeout)
        return self._check(response)
def image_to_base64(image_path, image_format="PNG"):
    """Re-encode the image at ``image_path`` and return it as base64 text.

    The image is opened with PIL and saved into an in-memory buffer in
    ``image_format`` before encoding. The output format therefore does not
    depend on the input file's format.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        image_format: Any format name PIL can write, e.g. ``"PNG"`` (default)
            or ``"JPEG"``. JPEG has no alpha channel, so images with
            transparency may need converting before using it.

    Returns:
        The base64 encoding of the re-encoded bytes, as a ``str``.
    """
    with Image.open(image_path) as image:
        buffered = BytesIO()
        image.save(buffered, format=image_format)
    # base64 output is pure ASCII, so decoding to str is lossless.
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
# Module-wide API client. The key is read from the PRODIA_API_KEY environment
# variable and is None when that variable is unset.
prodia_client = Prodia(api_key=os.environ.get("PRODIA_API_KEY"))
def flip_text(prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed):
    """Run one Prodia SDXL generation and return the resulting image URL.

    Wired to the "Generate" button. Each argument is forwarded unchanged as
    the request field of the same name to ``prodia_client.generate``.

    Returns:
        The finished job's ``imageUrl``, which the ``gr.Image`` output shows.

    Raises:
        gr.Error: if the job does not end in ``'succeeded'`` or carries no
            ``imageUrl``. Gradio shows this message to the user instead of
            failing with a bare ``KeyError``.
    """
    # NOTE(review): gr.Number / gr.Slider may deliver floats (e.g. seed=-1.0);
    # confirm the Prodia API accepts them, or cast integer fields here.
    result = prodia_client.generate({
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "model": model,
        "steps": steps,
        "sampler": sampler,
        "cfg_scale": cfg_scale,
        "width": width,
        "height": height,
        "seed": seed
    })
    job = prodia_client.wait(result)
    # wait() also returns jobs that ended in 'failed'; those should surface as
    # a readable error rather than an image.
    image_url = job.get("imageUrl")
    if job.get("status") != "succeeded" or not image_url:
        raise gr.Error(f"Image generation failed (status: {job.get('status')}).")
    return image_url
# Custom stylesheet passed to gr.Blocks(css=...) below. Its id/class selectors
# (#image-output-container, #image-output, #settings, .setting-group,
# #generate) match elem_id / elem_classes values assigned in the interface.
# NOTE(review): no component in this file sets elem_classes="container";
# confirm the .container rule is still needed.
css = """
/* Overall Styling */
body {
font-family: 'Arial', sans-serif;
}
.container {
display: flex;
flex-direction: column;
gap: 20px;
}
/* Image Output Area */
#image-output-container {
border: 2px solid #ccc;
border-radius: 8px;
overflow: hidden;
}
#image-output {
max-width: 100%;
height: auto;
}
/* Settings Section */
#settings {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 20px;
}
.setting-group {
border: 1px solid #ccc;
padding: 20px;
border-radius: 8px;
}
/* Button Styling */
#generate {
background-color: #007bff; /* Example - use your preferred color */
color: white;
padding: 15px 25px;
border: none;
border-radius: 5px;
cursor: pointer;
}
#generate:hover {
background-color: #0056b3; /* Darker shade on hover */
}
/* Responsive Design - Adjust breakpoints as needed */
@media screen and (max-width: 768px) {
#settings {
grid-template-columns: 1fr;
}
}
"""
# --- Gradio Interface ---
# Two tabs: a welcome screen whose "Get Started" button switches to the main
# generation screen, and the generation screen itself.
with gr.Blocks(css=css) as demo:
    # Name of the screen the user is currently on; set by "Get Started".
    state = gr.State(value="Welcome Screen")
    with gr.Tabs() as tabs:
        # Explicit tab ids let event handlers select a tab programmatically.
        with gr.TabItem("Welcome Screen", id="welcome"):
            with gr.Row():
                logo = gr.Image(
                    value="http://disneypixaraigenerator.com/wp-content/uploads/2023/12/cropped-android-chrome-512x512-1.png",
                    elem_id="logo",
                    height=200,
                    width=300
                )
            with gr.Row():
                title = gr.Markdown("<h1 style='text-align: center;'>Disney Pixar AI Generator</h1>", elem_id="title")
            with gr.Row():
                start_button = gr.Button("Get Started", variant='primary', elem_id="start-button")
        with gr.TabItem("Main Generation Screen", id="main"):
            with gr.Row():
                # HTML ids must be unique per page; "title" is already used on
                # the welcome tab.
                gr.Markdown("<h1 style='text-align: center;'>Create Your Disney Pixar AI Poster</h1>", elem_id="main-title")
            with gr.Row(elem_id="image-output-container"):
                image_output = gr.Image(
                    value="https://cdn-uploads.huggingface.co/production/uploads/noauth/XWJyh9DhMGXrzyRJk7SfP.png",
                    label="Generated Image",
                    elem_id="image-output"
                )
            with gr.Row(elem_id="settings"):
                # Column 1: prompts and the Generate button.
                with gr.Column(scale=1, min_width=300, elem_classes="setting-group"):
                    prompt = gr.Textbox(
                        "space warrior, beautiful, female, ultrarealistic, soft lighting, 8k",
                        placeholder="Enter your prompt here...",
                        show_label=False,
                        lines=3,
                        elem_id="prompt-input"
                    )
                    # NOTE(review): this default excludes "3d, cartoon", which
                    # seems at odds with a Disney/Pixar style; confirm intended.
                    negative_prompt = gr.Textbox(
                        placeholder="Enter negative prompts (optional)...",
                        show_label=False,
                        lines=3,
                        value="3d, cartoon, anime, (deformed eyes, nose, ears, nose), bad anatomy, ugly"
                    )
                    text_button = gr.Button("Generate", variant='primary', elem_id="generate")
                # Column 2: model/sampler choices, fetched from the API at build time.
                with gr.Column(scale=1, min_width=300, elem_classes="setting-group"):
                    model = gr.Dropdown(
                        interactive=True,
                        value="sd_xl_base_1.0.safetensors [be9edd61]",
                        show_label=True,
                        label="Model",
                        choices=prodia_client.list_models()
                    )
                    sampler = gr.Dropdown(
                        value="DPM++ 2M Karras",
                        show_label=True,
                        label="Sampling Method",
                        choices=prodia_client.list_samplers()
                    )
                    steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=25, value=20, step=1)
                # Column 3: image size, guidance and seed.
                with gr.Column(scale=1, min_width=300, elem_classes="setting-group"):
                    width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=8)
                    height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=8)
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=1)
                    seed = gr.Number(label="Seed", value=-1)

    text_button.click(flip_text, inputs=[prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed], outputs=image_output)
    # Previously this only toggled `visible` on the whole Tabs container (which
    # was already visible), so the button never navigated. Now it records the
    # new screen and selects the main tab directly.
    start_button.click(
        fn=lambda: ("Main Generation Screen", gr.update(selected="main")),
        inputs=None,
        outputs=[state, tabs],
    )

# Launch the Gradio app
demo.launch(max_threads=128)