creativity_hub / text_to_image.py
joyson's picture
Upload 5 files
9d9968c verified
raw
history blame
1.74 kB
import asyncio
import threading
from io import BytesIO

import torch
import spaces
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from PIL import Image

from utils import load_unet_model
# NOTE(review): ZeroGPU documents @spaces.GPU for decorating functions; here it wraps the
# class itself (i.e. construction), while generate_image is not decorated. Confirm this is
# the intended GPU scope.
@spaces.GPU
class TextToImage:
    """
    Text-to-Image generation using Stable Diffusion XL with the SDXL-Lightning 4-step UNet.

    The SDXL base pipeline (``stabilityai/stable-diffusion-xl-base-1.0``) is built with its
    UNet replaced by the one returned from ``utils.load_unet_model`` for the
    ``ByteDance/SDXL-Lightning`` checkpoint ``sdxl_lightning_4step_unet.safetensors``.
    """

    # One pipeline run at a time across all instances. Previously every run blocked the event
    # loop, which serialized them implicitly; now that runs execute on executor threads, this
    # lock keeps that guarantee (the pipeline and its scheduler are stateful during a run and
    # not safe to drive from several threads at once). Kept on the class rather than the
    # instance so instances remain picklable.
    _pipe_lock = threading.Lock()

    def __init__(self, device="cpu"):
        """
        Load the Lightning UNet and build the SDXL pipeline.

        Args:
            device (str): Torch device the UNet and pipeline are placed on (default "cpu").
        """
        # Model and repository details
        self.base = "stabilityai/stable-diffusion-xl-base-1.0"
        self.repo = "ByteDance/SDXL-Lightning"
        self.ckpt = "sdxl_lightning_4step_unet.safetensors"
        self.device = device

        # Load the UNet model
        print("Loading Text-to-Image model...")
        self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)

        # Initialize the pipeline with the Lightning UNet swapped in
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.base,
            unet=self.unet,
            torch_dtype=torch.float32,
        ).to(self.device)

        # Euler sampling with "trailing" timestep spacing, as the SDXL-Lightning model card
        # prescribes for its few-step checkpoints.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
        )
        print("Text-to-Image model loaded successfully.")

    async def generate_image(self, prompt):
        """
        Generate an image from a text prompt.

        The diffusion run itself is blocking, so it is executed on the running loop's default
        executor instead of on the event-loop thread; other coroutines keep making progress
        while an image is being generated. Runs are still performed one at a time.

        Args:
            prompt (str): The text prompt to generate the image.

        Returns:
            PIL.Image.Image: The generated image.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._generate_image_sync, prompt)

    def _generate_image_sync(self, prompt):
        """Run the pipeline synchronously for ``prompt`` and return the first image."""
        with self._pipe_lock:
            # Grad mode is thread-local in PyTorch, so no_grad must be entered here, in the
            # worker thread that actually runs the pipeline.
            with torch.no_grad():
                result = self.pipe(
                    prompt,
                    # Matches the 4-step Lightning checkpoint named in self.ckpt.
                    num_inference_steps=4,
                    # Lightning checkpoints are sampled without classifier-free guidance.
                    guidance_scale=0,
                )
        return result.images[0]