# text2image_1 / app.py
import base64
from io import BytesIO

import gradio as gr
from diffusers import AutoPipelineForText2Image
from PIL import Image

CONCURRENCY_LIMIT = 5

def load_model():
    print("Loading the Stable Diffusion model...")
    try:
        model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
        print("Model loaded successfully.")
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
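
# Optional tweak (an assumption, not part of the original app): on a CUDA GPU
# the pipeline is much faster in half precision. A minimal sketch:
#
#     import torch
#     model = AutoPipelineForText2Image.from_pretrained(
#         "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
#     ).to("cuda")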

# Load the pipeline once at startup; reloading it inside every request is slow
# and can exhaust memory when several requests run concurrently.
model = load_model()


def generate_image(prompt):
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")
        print(f"Generating image with prompt: {prompt}")
        # sdxl-turbo is tuned for single-step generation without classifier-free guidance.
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {output}")
        if output is None:
            raise ValueError("Model returned None")
        if hasattr(output, "images") and output.images:
            print("Image generated successfully")
            image = output.images[0]
            # Serialize the PIL image to a base64-encoded JPEG string.
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            print("Image encoded to base64")
            print(f"img_str: {img_str[:100]}...")  # Print a snippet of the base64 string
            return img_str, None
        else:
            print("No images found in model output")
            raise ValueError("No images found in model output")
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)
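
# Usage sketch (illustrative, not called anywhere in the app): a client can turn
# the returned base64 string back into a PIL image.
#
#     img_str, error = generate_image("a watercolor fox in a forest")
#     if img_str:
#         image = Image.open(BytesIO(base64.b64decode(img_str)))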

def inference(sentence_mapping, character_dict, selected_style):
    try:
        print(f"Received sentence_mapping: {sentence_mapping}, type: {type(sentence_mapping)}")
        print(f"Received character_dict: {character_dict}, type: {type(character_dict)}")
        print(f"Received selected_style: {selected_style}, type: {type(selected_style)}")
        if sentence_mapping is None or character_dict is None or selected_style is None:
            return {"error": "One or more inputs are None"}
        images = {}
        # Build one prompt per paragraph number; character_dict is accepted but
        # not yet used when composing the prompt.
        for paragraph_number, sentences in sentence_mapping.items():
            combined_sentence = " ".join(sentences)
            prompt = f"Make an illustration in {selected_style} style from: {combined_sentence}"
            print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
            img_str, error = generate_image(prompt)
            if error:
                images[paragraph_number] = f"Error: {error}"
            else:
                images[paragraph_number] = img_str
        return images
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return {"error": str(e)}

gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
    concurrency_limit=CONCURRENCY_LIMIT,
)
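
# Note: recent Gradio releases enable the request queue by default, so
# concurrency_limit caps how many inference calls run at the same time.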

if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()
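
# Example client call (hypothetical Space id, assuming the gradio_client package
# is installed and the Space is public):
#
#     from gradio_client import Client
#     client = Client("RanM/text2image_1")
#     result = client.predict(
#         {"1": ["A fox runs through the snow."]},  # Sentence Mapping
#         {},                                       # Character Dict
#         "watercolor",                             # Selected Style
#         api_name="/predict",
#     )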