# text2image_3 / app.py
# Last commit: 5a982ce "Update app.py" (author: RanM)
import base64
import os
import threading
from io import BytesIO

import gradio as gr
from diffusers import AutoPipelineForText2Image
from PIL import Image

from generate_prompts import generate_prompt
# Upper bound on simultaneous `inference` calls; handed to gr.Interface as
# its `concurrency_limit`.
CONCURRENCY_LIMIT: int = 10
# Process-wide cache for the text-to-image pipeline, guarded by a lock so that
# concurrent Gradio requests (up to CONCURRENCY_LIMIT) do not each load their
# own copy of the weights.
_pipeline_lock = threading.Lock()
_pipeline = None


def load_model():
    """Return the shared SDXL-Turbo pipeline, loading it on first use.

    The pipeline is created once with
    ``AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")``
    and the same instance is returned on later calls. A failed load is not
    cached, so the next call tries again.

    Returns:
        The pipeline object, or ``None`` if loading raised. The error is
        printed, not propagated.
    """
    global _pipeline
    with _pipeline_lock:
        if _pipeline is None:
            print("Loading the Stable Diffusion model...")
            try:
                _pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
            except Exception as e:
                # Best-effort contract: callers check for None instead of catching.
                print(f"Error loading model: {e}")
                return None
            print("Model loaded successfully.")
        return _pipeline
# NOTE(review): the pipeline returned by load_model() may be shared across
# requests, and its thread-safety is not established here. Gradio can run
# several requests at once, so calls into the model are serialized.
_inference_lock = threading.Lock()


def _encode_image_base64(image):
    """Serialize *image* as JPEG and return the bytes base64-encoded as ``str``.

    JPEG cannot hold alpha or palette data, so any image whose ``mode`` is
    not RGB or L is converted to RGB before saving.
    """
    if getattr(image, "mode", "RGB") not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def generate_image(prompt, model=None):
    """Generate one image for *prompt* and return it base64-encoded.

    Args:
        prompt: Text prompt passed to the pipeline.
        model: Optional pipeline to use instead of ``load_model()``. Any
            callable that accepts ``prompt``, ``num_inference_steps`` and
            ``guidance_scale`` keyword arguments and returns an object with
            an ``images`` sequence is accepted.

    Returns:
        ``(img_str, None)`` on success, where ``img_str`` is the first
        generated image as a base64-encoded JPEG string, or
        ``(None, message)`` on any failure. Errors are printed, never raised.
    """
    try:
        if model is None:
            model = load_model()
        if model is None:
            raise ValueError("Model not loaded properly.")
        print(f"Generating image with prompt: {prompt}")
        with _inference_lock:
            # Single step, no classifier-free guidance: the settings the
            # SDXL-Turbo model card recommends.
            output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {output}")
        if output is None:
            raise ValueError("Model returned None")
        if not getattr(output, "images", None):
            print("No images found in model output")
            raise ValueError("No images found in model output")
        print("Image generated successfully")
        img_str = _encode_image_base64(output.images[0])
        print("Image encoded to base64")
        print(f"img_str: {img_str[:100]}...")  # snippet only; the full string is large
        return img_str, None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)
def inference(sentence_mapping, character_dict, selected_style):
    """Generate one base64 image per paragraph.

    Args:
        sentence_mapping: Mapping of paragraph number -> that paragraph's
            sentences, either a list of strings or a single string.
        character_dict: Forwarded unchanged to ``generate_prompt``.
        selected_style: Style name forwarded unchanged to ``generate_prompt``.

    Returns:
        A dict keyed by paragraph number. Each value is the string returned by
        ``generate_image`` (a base64-encoded JPEG), or ``"Error: <message>"``
        if that paragraph's image failed. If anything else raises (e.g.
        ``sentence_mapping`` has no ``.items()``), returns
        ``{"error": <message>}`` instead.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}, type: {type(sentence_mapping)}")
        print(f"Received character_dict: {character_dict}, type: {type(character_dict)}")
        print(f"Received selected_style: {selected_style}, type: {type(selected_style)}")
        images = {}
        for paragraph_number, sentences in sentence_mapping.items():
            # A bare string would be joined character by character
            # ("ab" -> "a b"), so use it as the already-combined text.
            if isinstance(sentences, str):
                combined_sentence = sentences
            else:
                combined_sentence = " ".join(sentences)
            prompt = generate_prompt(combined_sentence, character_dict, selected_style)
            print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
            img_str, error = generate_image(prompt)
            images[paragraph_number] = f"Error: {error}" if error else img_str
        return images
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return {"error": str(e)}
# Styles offered in the dropdown; the selection reaches inference() as
# `selected_style`.
_STYLE_CHOICES = ["oil painting", "sketch", "watercolor"]

# Positional inputs, in the order inference() takes its parameters:
# sentence_mapping, character_dict, selected_style.
_interface_inputs = [
    gr.JSON(label="Sentence Mapping"),
    gr.JSON(label="Character Dict"),
    gr.Dropdown(_STYLE_CHOICES, label="Selected Style"),
]

# The single JSON output shows the dict returned by inference().
gradio_interface = gr.Interface(
    fn=inference,
    inputs=_interface_inputs,
    outputs="json",
    concurrency_limit=CONCURRENCY_LIMIT,
)
def _main():
    """Start the Gradio app defined by `gradio_interface`."""
    print("Launching Gradio interface...")
    gradio_interface.launch()


if __name__ == "__main__":
    _main()