Spaces:
Runtime error
Runtime error
import torch | |
import spaces | |
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler | |
from PIL import Image | |
from io import BytesIO | |
from utils import load_unet_model | |
class TextToImage:
    """
    Class to handle Text-to-Image generation using Stable Diffusion XL.

    The SDXL base pipeline is loaded with the UNet returned by
    ``load_unet_model`` for the SDXL-Lightning 4-step checkpoint. Inference
    runs on a dedicated single worker thread, so the async ``generate_image``
    entry point does not block the caller's event loop while an image renders.
    """

    def __init__(self, device="cpu"):
        """
        Load the UNet, build the pipeline and configure its scheduler.

        Args:
            device (str): Torch device the pipeline is moved to
                (e.g. "cpu" or "cuda").
        """
        from concurrent.futures import ThreadPoolExecutor

        # Model and repository details
        self.base = "stabilityai/stable-diffusion-xl-base-1.0"
        self.repo = "ByteDance/SDXL-Lightning"
        self.ckpt = "sdxl_lightning_4step_unet.safetensors"
        self.device = device

        # Load the UNet model (built by utils.load_unet_model from
        # base/repo/ckpt; its dtype/placement are decided by that helper).
        print("Loading Text-to-Image model...")
        self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)

        # Initialize the SDXL pipeline, swapping in the UNet loaded above.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.base,
            unet=self.unet,
            torch_dtype=torch.float32,
        ).to(self.device)

        # Replace the scheduler with Euler using "trailing" timestep spacing,
        # the configuration shown in the SDXL-Lightning diffusers usage example.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
        )

        # One worker thread: requests queue up and run one at a time, because
        # the shared pipeline/scheduler keeps per-run state and must not be
        # driven by two generations concurrently.
        self._executor = ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="text-to-image"
        )
        print("Text-to-Image model loaded successfully.")

    async def generate_image(self, prompt):
        """
        Generate an image from a text prompt.

        The blocking diffusion call is executed on this instance's worker
        thread, so the awaiting event loop stays free to serve other tasks
        while the image is rendered.

        Args:
            prompt (str): The text prompt to generate the image.
        Returns:
            PIL.Image: The generated image.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, self._generate_sync, prompt)

    def _generate_sync(self, prompt):
        """Run the pipeline synchronously and return its first image."""
        # Autograd mode is thread-local, so it must be disabled here on the
        # worker thread rather than in the coroutine.
        with torch.no_grad():
            result = self.pipe(
                prompt,
                num_inference_steps=4,  # matches the 4-step checkpoint (self.ckpt)
                guidance_scale=0,  # classifier-free guidance off, as in the original
            )
        return result.images[0]