import gradio as gr
from diffusers import DPMSolverMultistepScheduler, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
import torch
from tqdm.auto import tqdm
from time import time
from PIL import Image

# Hub repository every sub-model below is loaded from.
modelId = "KBlueLeaf/Kohaku-XL-Epsilon-rev3"

# NOTE(review): `allow_pickle=True` is carried over unchanged from the original
# calls. It does not appear to be a documented `from_pretrained` argument in
# diffusers or transformers -- confirm it is accepted (or drop it).
vae = AutoencoderKL.from_pretrained(modelId, subfolder="vae", allow_pickle=True)
tokenizer = CLIPTokenizer.from_pretrained(modelId, subfolder="tokenizer", allow_pickle=True)
textEncoder = CLIPTextModel.from_pretrained(modelId, subfolder="text_encoder", allow_pickle=True)
unet = UNet2DConditionModel.from_pretrained(modelId, subfolder="unet", allow_pickle=True)
scheduler = DPMSolverMultistepScheduler.from_pretrained(modelId, subfolder="scheduler", allow_pickle=True)

# Prefer the GPU, but fall back to the CPU so the script still starts without CUDA.
torchDevice = "cuda" if torch.cuda.is_available() else "cpu"
vae.to(torchDevice)
textEncoder.to(torchDevice)
unet.to(torchDevice)

# Ratio between pixel resolution and latent resolution (8 for the usual 4-level VAE).
vaeScaleFactor = 2 ** (len(vae.config.block_out_channels) - 1)


def _encodePrompt(text: str) -> torch.Tensor:
    """Tokenize *text* and return the text encoder's last hidden state (batch of 1)."""
    tokens = tokenizer(
        text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Follow the diffusers reference pipelines: only pass the attention mask when
    # the encoder config asks for it, and treat both prompts the same way.
    attentionMask = None
    if getattr(textEncoder.config, "use_attention_mask", False):
        attentionMask = tokens.attention_mask.to(torchDevice)
    return textEncoder(tokens.input_ids.to(torchDevice), attention_mask=attentionMask)[0]


def _decodeLatents(latents: torch.Tensor) -> Image.Image:
    """Decode a batch-of-one latent tensor with the VAE into an 8-bit RGB PIL image."""
    # Use the scaling factor stored with the VAE instead of a hard-coded constant.
    scalingFactor = getattr(vae.config, "scaling_factor", 0.18215)
    decoded = vae.decode(latents / scalingFactor).sample
    # Map from [-1, 1] to [0, 1], then to HWC uint8. Scale by 255 exactly once.
    pixels = (decoded / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0)
    return Image.fromarray((pixels * 255).round().to(torch.uint8).cpu().numpy())


def generate(prompt: str, negativePrompt: str, steps: int, cfg: float, seed: int, randomized: bool, width: int, height: int):
    """Generate one image from a text prompt using classifier-free guidance.

    Args:
        prompt: Text describing the desired image.
        negativePrompt: Text used for the unconditional branch of the guidance
            (an empty string gives plain classifier-free guidance).
        steps: Number of scheduler steps; must be at least 1.
        cfg: Guidance scale applied to (text - unconditional) noise prediction.
        seed: Seed for the initial latent noise; ignored when *randomized* is true.
        randomized: When true, draw a fresh random seed for this call.
        width: Output width in pixels; must be a multiple of the VAE scale factor.
        height: Output height in pixels; must be a multiple of the VAE scale factor.

    Returns:
        The generated image as a ``PIL.Image.Image``.

    Raises:
        ValueError: If *steps* is below 1 or the size is not a multiple of the
            VAE scale factor.
    """
    # UI widgets may deliver numbers as floats; shapes and seeds need real ints.
    steps, width, height = int(steps), int(width), int(height)
    if steps < 1:
        raise ValueError("steps must be at least 1")
    if width % vaeScaleFactor or height % vaeScaleFactor:
        raise ValueError(f"width and height must be multiples of {vaeScaleFactor}")

    if randomized or seed is None:
        # Upper bound is exclusive and must fit in int64.
        seed = torch.randint(10000, 2**63 - 1, (1,)).item()
    # The generator must live on the same device as the tensor it fills.
    generator = torch.Generator(device=torchDevice).manual_seed(int(seed))

    # NOTE(review): the checkpoint name suggests an SDXL model. SDXL UNets
    # normally also need a second text encoder and `added_cond_kwargs`
    # (pooled text embeds + time ids) -- confirm this single-encoder path
    # matches the loaded UNet's config.
    with torch.no_grad():
        # Order matters: unconditional first, text second (see chunk(2) below).
        textEmbeddings = torch.cat([_encodePrompt(negativePrompt or ""), _encodePrompt(prompt)])

        scheduler.set_timesteps(steps, device=torchDevice)
        latents = torch.randn(
            (1, unet.config.in_channels, height // vaeScaleFactor, width // vaeScaleFactor),
            generator=generator,
            device=torchDevice,
        )
        latents = latents * scheduler.init_noise_sigma

        for t in tqdm(scheduler.timesteps):
            # Run both guidance branches in a single batched forward pass.
            latentModelInput = scheduler.scale_model_input(torch.cat([latents] * 2), timestep=t)
            noisePred = unet(latentModelInput, t, encoder_hidden_states=textEmbeddings).sample
            unconditionedNoisePred, noisePredText = noisePred.chunk(2)
            noisePred = unconditionedNoisePred + cfg * (noisePredText - unconditionedNoisePred)
            latents = scheduler.step(noisePred, t, latents).prev_sample

        return _decodeLatents(latents)


interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(lines=3, placeholder="Prompt is here...", label="Prompt"),
        gr.Textbox(lines=3, placeholder="Negative prompt is here...", label="Negative Prompt"),
        gr.Slider(0, 1000, step=1, label="Steps", value=20),
        gr.Slider(0, 50, step=0.1, label="CFG Scale", value=8),
        gr.Number(label="Seed", value=0),
        gr.Checkbox(label="Randomize Seed", value=True),
        gr.Slider(256, 999999, step=64, label="Width", value=512),
        gr.Slider(256, 999999, step=64, label="Height", value=512),
    ],
    outputs="image",
)

if __name__ == "__main__":
    interface.launch()