from __future__ import annotations import PIL.Image import torch from diffusers import (StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline) class Model: def __init__(self): self.device = torch.device( 'cuda:0' if torch.cuda.is_available() else 'cpu') model_id = 'CompVis/stable-diffusion-v1-4' self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained( model_id) self.ax_pipe.to(self.device) self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id) self.sd_pipe.to(self.device) def get_token_table(self, prompt: str): tokens = [ self.ax_pipe.tokenizer.decode(t) for t in self.ax_pipe.tokenizer(prompt)['input_ids'] ] tokens = tokens[1:-1] return list(enumerate(tokens, start=1)) def run( self, prompt: str, indices_to_alter_str: str, seed: int = 0, apply_attend_and_excite: bool = True, num_steps: int = 50, guidance_scale: float = 7.5, scale_factor: int = 20, thresholds: dict[int, float] = { 10: 0.5, 20: 0.8, }, max_iter_to_alter: int = 25, ) -> PIL.Image.Image: generator = torch.Generator(device=self.device).manual_seed(seed) if apply_attend_and_excite: try: token_indices = list(map(int, indices_to_alter_str.split(','))) except Exception: raise ValueError('Invalid token indices.') out = self.ax_pipe( prompt=prompt, token_indices=token_indices, guidance_scale=guidance_scale, generator=generator, num_inference_steps=num_steps, max_iter_to_alter=max_iter_to_alter, thresholds=thresholds, scale_factor=scale_factor, ) else: out = self.sd_pipe( prompt=prompt, guidance_scale=guidance_scale, generator=generator, num_inference_steps=num_steps, ) return out.images[0]