from typing import Literal
from pathlib import Path

import uuid
import json
import re
import asyncio

import toml

import torch

from compel import Compel
from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    DDPMScheduler,
    DPMSolverSinglestepScheduler,
    DPMSolverSDEScheduler,
    DEISMultistepScheduler,
)

from .utils import (
    set_all_seeds,
)
from .palmchat import (
    palm_prompts,
    gen_text,
)

# Index of the GPU that the next ImageMaker without an explicit device index
# will be placed on; advanced round-robin by the ``device`` setter.
_gpus = 0


class ImageMaker:
    """Generate images from text prompts with a Stable Diffusion pipeline.

    The pipeline is loaded once in ``__init__``; ``text2image`` renders a
    prompt to a PNG file under ``./outputs``. The ``generate_*_prompts``
    methods ask the PaLM helper (``gen_text``) for prompt text.
    """

    # Supported aspect ratios mapped to [width, height] in pixels.
    __ratio = {'3:2': [768, 512],
               '4:3': [680, 512],
               '16:9': [912, 512],
               '1:1': [512, 512],
               '9:16': [512, 912],
               '3:4': [512, 680],
               '2:3': [512, 768]}
    # Becomes True (per instance) once a device has been assigned.
    __allocated = False
    # Directory that generated images are written to.
    __output_dir = Path('.') / 'outputs'
    # Descriptors that receive extra prompt weight in character prompts.
    __gender_keywords = frozenset(['1man', '1woman', '1boy', '1girl',
                                   '1male', '1female', '1gentleman', '1lady'])

    def __init__(self, model_base: str,
                 clip_skip: int = 2,
                 sampling: Literal['sde-dpmsolver++'] = 'sde-dpmsolver++',
                 vae: str = None,
                 safety: bool = True,
                 neg_prompt: str = None,
                 device: str = None) -> None:
        """Initialize the ImageMaker class.

        Args:
            model_base (str): Filename of the model base.
            clip_skip (int, optional): Number of layers to skip in the clip model. Defaults to 2.
            sampling (Literal['sde-dpmsolver++'], optional): Sampling method. Defaults to 'sde-dpmsolver++'.
                Currently only stored; the scheduler is always 'sde-dpmsolver++'.
            vae (str, optional): Filename of the VAE model. Defaults to None.
            safety (bool, optional): Whether to use the safety checker. Defaults to True.
            neg_prompt (str, optional): Default negative prompt used when ``text2image``
                is called without one. Defaults to None (stored as an empty string).
            device (str, optional): Device to use for the model. Defaults to None,
                which selects 'cuda' when available and 'cpu' otherwise.

        Raises:
            FileExistsError: If a regular file occupies the output directory path.
        """
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.__model_base = model_base
        self.__clip_skip = clip_skip
        self.__sampling = sampling
        self.__vae = vae
        self.__safety = safety
        self.neg_prompt = neg_prompt

        print("Loading the Stable Diffusion model into memory...")
        self.__sd_model = StableDiffusionPipeline.from_single_file(self.model_base,
                                                                   torch_dtype=torch.float16,
                                                                   use_safetensors=True,
                                                                   )

        # Clip Skip: drop the last (clip_skip - 1) text-encoder layers. The layer
        # count is read from the model instead of assuming a 12-layer encoder.
        encoder = self.__sd_model.text_encoder.text_model.encoder
        encoder.layers = encoder.layers[:len(encoder.layers) - (self.clip_skip - 1)]

        # Sampling method
        # TODO: Sampling method :: honour self.sampling once more samplers are supported
        scheduler = DPMSolverMultistepScheduler.from_config(self.__sd_model.scheduler.config)
        scheduler.config.algorithm_type = 'sde-dpmsolver++'
        self.__sd_model.scheduler = scheduler

        # TODO: Use LoRA

        # VAE
        if self.vae:
            vae_model = AutoencoderKL.from_pretrained(self.vae, torch_dtype=torch.float16)
            self.__sd_model.vae = vae_model

        if not self.safety:
            self.__sd_model.safety_checker = None
            self.__sd_model.requires_safety_checker = False

        print(f"Loaded model to {self.device}")
        self.__sd_model = self.__sd_model.to(self.device)

        # Text Encoder using Compel
        self.__compel_proc = Compel(tokenizer=self.__sd_model.tokenizer,
                                    text_encoder=self.__sd_model.text_encoder,
                                    truncate_long_prompts=False)

        # Make sure the output directory exists. An explicit exception is raised
        # (rather than ``assert``) so the check survives ``python -O``.
        output_dir = self.__output_dir
        if output_dir.is_file():
            raise FileExistsError(
                f"A file with the same name as the desired directory ('{str(output_dir)}') already exists."
            )
        output_dir.mkdir(parents=True, exist_ok=True)

    def text2image(self,
                   prompt: str, neg_prompt: str = None,
                   ratio: Literal['3:2', '4:3', '16:9', '1:1', '9:16', '3:4', '2:3'] = '1:1',
                   step: int = 28,
                   cfg: float = 4.5,
                   seed: int = None) -> str:
        """Generate an image from the prompt.

        Args:
            prompt (str): Prompt for the image generation.
            neg_prompt (str, optional): Negative prompt for the image generation.
                Falls back to the instance-level ``neg_prompt`` when empty. Defaults to None.
            ratio (Literal['3:2', '4:3', '16:9', '1:1', '9:16', '3:4', '2:3'], optional):
                Ratio of the generated image. Defaults to '1:1'.
            step (int, optional): Number of inference steps for the diffusion. Defaults to 28.
            cfg (float, optional): Guidance scale passed to the pipeline. Defaults to 4.5.
            seed (int, optional): Seed for the random number generator. ``None`` or ``-1``
                picks a random seed. Defaults to None.

        Returns:
            str: Path to the generated image.

        Raises:
            KeyError: If ``ratio`` is not one of the supported ratios.
            ValueError: If the safety checker is enabled and flags the image.
        """
        output_path = (self.__output_dir / str(uuid.uuid4())).with_suffix('.png')

        # Only None and -1 mean "random"; 0 is a legitimate, reproducible seed.
        if seed is None or seed == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        set_all_seeds(seed)

        width, height = self.__ratio[ratio]

        prompt_embeds, negative_prompt_embeds = self.__get_pipeline_embeds(prompt, neg_prompt or self.neg_prompt)

        # Generate the image
        result = self.__sd_model(prompt_embeds=prompt_embeds,
                                 negative_prompt_embeds=negative_prompt_embeds,
                                 guidance_scale=cfg,
                                 num_inference_steps=step,
                                 width=width,
                                 height=height,
                                 )
        if self.__safety and result.nsfw_content_detected[0]:
            print("=== NSFW Content Detected ===")
            raise ValueError("Potential NSFW content was detected in one or more images.")

        img = result.images[0]
        img.save(str(output_path))

        return str(output_path)

    def generate_character_prompts(self, character_name: str, age: str, job: str,
                                   keywords: list[str] = None,
                                   creative_mode: Literal['sd character', 'cartoon', 'realistic'] = 'cartoon') -> tuple[str, str]:
        """Generate positive and negative prompts for a character based on given attributes.

        Args:
            character_name (str): Character's name.
            age (str): Age of the character.
            job (str): The profession or job of the character.
            keywords (list[str], optional): List of descriptive words for the character.
                Defaults to None, which is sent to PaLM as 'Nothing'.
            creative_mode (Literal['sd character', 'cartoon', 'realistic'], optional):
                Accepted for interface compatibility; currently not used. Defaults to 'cartoon'.

        Returns:
            tuple[str, str]: A tuple of positive and negative prompts (both lower-cased).

        Raises:
            TimeoutError: If the PaLM API does not answer within 10 seconds.
            ValueError: If the PaLM response is not in the expected format.
        """
        positive = ""  # add static prompt for character if needed (e.g. "chibi, cute, anime")
        negative = palm_prompts['image_gen']['neg_prompt']

        # Generate prompts with PaLM
        t = palm_prompts['image_gen']['character']['gen_prompt']
        q = palm_prompts['image_gen']['character']['query']
        query_string = t.format(input=q.format(character_name=character_name,
                                               job=job,
                                               age=age,
                                               keywords=', '.join(keywords) if keywords else 'Nothing'))

        primary_sentence, descriptors = self._request_prompt_parts(query_string)

        # Gender descriptors get extra prompt weight ('+++').
        weighted = [w + '+++' if w in self.__gender_keywords else w for w in descriptors]
        positive = self._compose_positive(positive, primary_sentence, weighted).lower()

        # Emphasise every mention of the job with a single '+'. The prompt is
        # lower-cased first so capitalised mentions are matched too, and an empty
        # job is skipped (splitting on an empty separator would raise).
        job_lc = job.lower() if job else ''
        if job_lc:
            positive = positive.replace(job_lc, f'{job_lc}+')

        return (positive, negative.lower())

    def generate_background_prompts(self, genre:str, place:str, mood:str,
                                    title:str, chapter_title:str, chapter_plot:str) -> tuple[str, str]:
        """Generate positive and negative prompts for a background image based on given attributes.

        Args:
            genre (str): Genre of the story.
            place (str): Place of the story.
            mood (str): Mood of the story.
            title (str): Title of the story.
            chapter_title (str): Title of the chapter.
            chapter_plot (str): Plot of the chapter.

        Returns:
            tuple[str, str]: A tuple of positive and negative prompts (both lower-cased).

        Raises:
            TimeoutError: If the PaLM API does not answer within 10 seconds.
            ValueError: If the PaLM response is not in the expected format.
        """
        # Static prompt for backgrounds.
        # NOTE(review): "catoon" looks like a typo for "cartoon" -- left untouched
        # because changing it alters the generated images; confirm with the author.
        positive = "painting+++, anime+, catoon, watercolor, wallpaper, text---"
        negative = "realistic, human, character, people, photograph, 3d render, blurry, grayscale, oversaturated, " + palm_prompts['image_gen']['neg_prompt']

        # Generate prompts with PaLM
        t = palm_prompts['image_gen']['background']['gen_prompt']
        q = palm_prompts['image_gen']['background']['query']
        query_string = t.format(input=q.format(genre=genre,
                                               place=place,
                                               mood=mood,
                                               title=title,
                                               chapter_title=chapter_title,
                                               chapter_plot=chapter_plot))

        primary_sentence, descriptors = self._request_prompt_parts(query_string)
        positive = self._compose_positive(positive, primary_sentence, descriptors)

        return (positive.lower(), negative.lower())

    def _request_prompt_parts(self, query_string: str) -> tuple[str, list[str]]:
        """Send a query to PaLM and extract the prompt parts from its JSON answer.

        Args:
            query_string (str): Fully formatted query for ``gen_text``.

        Returns:
            tuple[str, list[str]]: The ``primary_sentence`` and the ``descriptors``
            entries of the JSON object returned by PaLM.

        Raises:
            TimeoutError: If the PaLM API does not answer within 10 seconds.
            ValueError: If the response is not JSON or lacks the expected string fields.
        """
        try:
            response, response_txt = asyncio.run(asyncio.wait_for(
                gen_text(query_string, mode="text", use_filter=False),
                timeout=10)
            )
        except asyncio.TimeoutError as err:
            raise TimeoutError("The response time for PaLM API exceeded the limit.") from err

        try:
            res_json = json.loads(response_txt)
            primary_sentence = res_json['primary_sentence']
            descriptors = list(res_json['descriptors'])
            if not isinstance(primary_sentence, str) or not all(isinstance(d, str) for d in descriptors):
                raise TypeError("'primary_sentence' and 'descriptors' must contain strings.")
        except Exception as err:
            # ``Exception`` (not a bare ``except``) so that KeyboardInterrupt and
            # SystemExit still propagate; the original cause is chained.
            print("=== PaLM Response ===")
            print(getattr(response, 'filters', None))
            print(response_txt)
            print("=== PaLM Response ===")
            raise ValueError("The response from PaLM API is not in the expected format.") from err

        return primary_sentence, descriptors

    @staticmethod
    def _compose_positive(static_prompt: str, primary_sentence: str, descriptors: list[str]) -> str:
        """Join the optional static prompt, the primary sentence and the descriptors with ', '."""
        head = primary_sentence if not static_prompt else f"{static_prompt}, {primary_sentence}"
        return head + ", " + ', '.join(descriptors)

    def __get_pipeline_embeds(self, prompt:str, negative_prompt:str) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Get pipeline embeds for prompts bigger than the maxlength of the pipeline

        Args:
            prompt (str): Prompt for the image generation.
            negative_prompt (str): Negative prompt for the image generation.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple of positive and negative prompt embeds.
        """
        conditioning = self.__compel_proc.build_conditioning_tensor(prompt)
        negative_conditioning = self.__compel_proc.build_conditioning_tensor(negative_prompt)
        # Both embeds must have the same sequence length to be used together.
        return self.__compel_proc.pad_conditioning_tensors_to_same_length([conditioning, negative_conditioning])

    @property
    def model_base(self):
        """Model base

        Returns:
            str: The model base (read-only)
        """
        return self.__model_base

    @property
    def clip_skip(self):
        """Clip Skip

        Returns:
            int: The number of layers to skip in the clip model (read-only)
        """
        return self.__clip_skip

    @property
    def sampling(self):
        """Sampling method

        Returns:
            Literal['sde-dpmsolver++']: The sampling method (read-only)
        """
        return self.__sampling

    @property
    def vae(self):
        """VAE

        Returns:
            str: The VAE (read-only)
        """
        return self.__vae

    @property
    def safety(self):
        """Safety checker

        Returns:
            bool: Whether to use the safety checker (read-only)
        """
        return self.__safety

    @property
    def device(self):
        """Device

        Returns:
            str: The device the model is placed on (can be assigned only once)
        """
        return self.__device

    @device.setter
    def device(self, value):
        """Assign the device once; GPUs without an explicit index are picked round-robin.

        Args:
            value (str | torch.device): 'cpu', a device type such as 'cuda', or a
                fully indexed device such as 'cuda:1'.

        Raises:
            RuntimeError: If a device has already been assigned to this instance.
        """
        global _gpus

        if self.__allocated:
            raise RuntimeError("Cannot change device after the model is loaded.")

        # Normalise torch.device objects to their string form so that
        # torch.device('cpu') is handled the same way as 'cpu'.
        device_name = str(value)

        if device_name == 'cpu' or ':' in device_name:
            # CPU, or an explicit index chosen by the caller: use it verbatim
            # instead of appending a second index (e.g. 'cuda:1:0').
            self.__device = device_name
        else:
            self.__device = f'{device_name}:{_gpus}'
            max_gpu = torch.cuda.device_count()
            _gpus = (_gpus + 1) if (_gpus + 1) < max_gpu else 0

        # Lock the device on every path; reassigning it would not move the model.
        self.__allocated = True

    @property
    def neg_prompt(self):
        """Negative prompt

        Returns:
            str: The negative prompt
        """
        return self.__neg_prompt

    @neg_prompt.setter
    def neg_prompt(self, value):
        # None and other falsy values are normalised to an empty string.
        self.__neg_prompt = value if value else ""