Spaces:
Sleeping
Sleeping
File size: 3,109 Bytes
e199da3 dee6e4c affc855 e199da3 0f197c4 5a982ce a7a041a ea8a0d1 e199da3 0f197c4 ea8a0d1 e199da3 dee6e4c c064275 dee6e4c e199da3 dee6e4c e199da3 dee6e4c e199da3 0f197c4 e199da3 dee6e4c 8d29889 dee6e4c 5a982ce dee6e4c e199da3 dee6e4c e199da3 ea8a0d1 a7a041a e199da3 ea8a0d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os
from io import BytesIO
from PIL import Image
from diffusers import AutoPipelineForText2Image
import gradio as gr
import base64
from generate_prompts import generate_prompt
# Passed to gr.Interface below as its `concurrency_limit` argument.
CONCURRENCY_LIMIT = 10
def load_model():
    """Return the SDXL-Turbo text-to-image pipeline, loading it at most once.

    The first successful call builds the pipeline with
    ``AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")``
    and stores it on this function object; later calls return that same
    instance instead of loading the model again. A failed load is not
    remembered, so the next call retries.

    Returns:
        The loaded pipeline, or ``None`` if loading raised an exception
        (the error is printed rather than propagated).
    """
    cached = getattr(load_model, "_pipeline", None)
    if cached is not None:
        return cached

    print("Loading the Stable Diffusion model...")
    try:
        model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    else:
        print("Model loaded successfully.")

    # NOTE(review): no lock is taken here, so requests that arrive together
    # before the first load finishes may each load the model once; the last
    # one to finish is the instance that stays cached.
    load_model._pipeline = model
    return model
def generate_image(prompt):
    """Render one image for *prompt* and return it as base64-encoded JPEG text.

    The pipeline is obtained from ``load_model()`` and invoked with a single
    inference step and a guidance scale of 0.0. The first image of the result
    is saved as JPEG into an in-memory buffer and base64-encoded.

    Args:
        prompt: Prompt forwarded to the pipeline as ``prompt=``.

    Returns:
        A ``(img_str, error)`` pair. On success ``img_str`` is the base64
        string and ``error`` is ``None``. On any failure ``img_str`` is
        ``None`` and ``error`` is the text of the exception.
    """
    pipeline = load_model()
    try:
        if pipeline is None:
            raise ValueError("Model not loaded properly.")

        print(f"Generating image with prompt: {prompt}")
        result = pipeline(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {result}")
        if result is None:
            raise ValueError("Model returned None")

        # Bail out early when the result carries no usable image list.
        if not (hasattr(result, "images") and result.images):
            print("No images found in model output")
            raise ValueError("No images found in model output")

        print("Image generated successfully")
        jpeg_buffer = BytesIO()
        result.images[0].save(jpeg_buffer, format="JPEG")
        encoded = base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")
        print("Image encoded to base64")
        # Only the first 100 characters are logged to keep the output readable.
        print(f"img_str: {encoded[:100]}...")
        return encoded, None
    except Exception as exc:
        # Every failure, including the ValueErrors raised above, is reported
        # through the second element of the returned pair.
        print(f"An error occurred while generating image: {exc}")
        return None, str(exc)
def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph and return them keyed by paragraph.

    For every entry of ``sentence_mapping`` the sentences are joined with a
    single space, turned into a prompt by ``generate_prompt`` and rendered by
    ``generate_image``.

    Args:
        sentence_mapping: Mapping (anything with ``.items()``) from a
            paragraph key to the sentences of that paragraph. Presumably each
            value is a list of strings; the code only requires something
            ``" ".join`` accepts.
        character_dict: Forwarded unchanged to ``generate_prompt``.
        selected_style: Forwarded unchanged to ``generate_prompt``.

    Returns:
        A dict with one entry per paragraph key: the base64 image string on
        success, or ``"Error: <message>"`` when that paragraph failed. If the
        function itself raises, ``{"error": "<message>"}`` is returned
        instead.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}, type: {type(sentence_mapping)}")
        print(f"Received character_dict: {character_dict}, type: {type(character_dict)}")
        print(f"Received selected_style: {selected_style}, type: {type(selected_style)}")
        images = {}
        for paragraph_number, sentences in sentence_mapping.items():
            combined_sentence = " ".join(sentences)
            prompt = generate_prompt(combined_sentence, character_dict, selected_style)
            print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
            img_str, error = generate_image(prompt)
            # generate_image returns error=None only on success. Compare
            # against None rather than relying on truthiness, because an
            # exception raised without a message yields error == "" and would
            # otherwise be mistaken for a successful render.
            if error is not None:
                images[paragraph_number] = f"Error: {error}"
            else:
                images[paragraph_number] = img_str
        return images
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return {"error": str(e)}
# Gradio front end for `inference`: two JSON inputs plus a style dropdown,
# with the result returned as JSON.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(
            choices=["oil painting", "sketch", "watercolor"],
            label="Selected Style",
        ),
    ],
    outputs="json",
    concurrency_limit=CONCURRENCY_LIMIT,
)

# Start the web UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()
|