Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import PIL.Image | |
import torch | |
from diffusers import (StableDiffusionAttendAndExcitePipeline, | |
StableDiffusionPipeline) | |
class Model:
    """Text-to-image helper built on two diffusers pipelines.

    Both an Attend-and-Excite pipeline (``ax_pipe``) and a plain Stable
    Diffusion pipeline (``sd_pipe``) are loaded from the same
    ``CompVis/stable-diffusion-v1-4`` checkpoint, so ``run`` can produce an
    image with or without Attend-and-Excite.
    """

    def __init__(self):
        """Load both pipelines and move them to CUDA when it is available."""
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        model_id = 'CompVis/stable-diffusion-v1-4'
        self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
            model_id)
        self.ax_pipe.to(self.device)
        self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
        self.sd_pipe.to(self.device)

    def get_token_table(self, prompt: str):
        """Return ``(index, token)`` pairs for the tokens of ``prompt``.

        Args:
            prompt: Text to tokenize with the Attend-and-Excite pipeline's
                tokenizer.

        Returns:
            A list of ``(index, decoded_token)`` tuples, numbered from 1.
        """
        tokens = [
            self.ax_pipe.tokenizer.decode(t)
            for t in self.ax_pipe.tokenizer(prompt)['input_ids']
        ]
        # The first and last ids are dropped -- presumably the tokenizer's
        # begin/end-of-text markers (TODO confirm). Numbering from 1 then
        # keeps each index equal to the token's position in the full,
        # unsliced id sequence.
        tokens = tokens[1:-1]
        return list(enumerate(tokens, start=1))

    @staticmethod
    def _parse_token_indices(indices_to_alter_str: str) -> list[int]:
        """Parse a comma-separated string such as ``'2,5'`` into ``[2, 5]``."""
        try:
            return [int(part) for part in indices_to_alter_str.split(',')]
        except (AttributeError, TypeError, ValueError) as err:
            # Keep the single user-facing error type/message, but chain the
            # underlying cause so it is still visible when debugging.
            raise ValueError('Invalid token indices.') from err

    def run(
        self,
        prompt: str,
        indices_to_alter_str: str,
        seed: int = 0,
        apply_attend_and_excite: bool = True,
        num_steps: int = 50,
        guidance_scale: float = 7.5,
        scale_factor: int = 20,
        thresholds: dict[int, float] | None = None,
        max_iter_to_alter: int = 25,
    ) -> PIL.Image.Image:
        """Generate one image for ``prompt``.

        Args:
            prompt: Text prompt passed to the selected pipeline.
            indices_to_alter_str: Comma-separated integer token indices
                (e.g. ``'2,5'``). Only parsed when
                ``apply_attend_and_excite`` is true.
            seed: Seed for the ``torch.Generator`` given to the pipeline.
            apply_attend_and_excite: Use ``ax_pipe`` when true, otherwise the
                plain ``sd_pipe``.
            num_steps: Forwarded as ``num_inference_steps``.
            guidance_scale: Forwarded as ``guidance_scale``.
            scale_factor: Forwarded to ``ax_pipe`` as ``scale_factor``.
            thresholds: Forwarded to ``ax_pipe`` as ``thresholds``. ``None``
                (the default) means ``{10: 0.5, 20: 0.8}``.
            max_iter_to_alter: Forwarded to ``ax_pipe`` as
                ``max_iter_to_alter``.

        Returns:
            The first image of the pipeline output.

        Raises:
            ValueError: If ``apply_attend_and_excite`` is true and
                ``indices_to_alter_str`` cannot be parsed as integers.
        """
        if thresholds is None:
            # Build a fresh dict on every call: a dict literal used as a
            # default argument would be one object shared by all calls.
            thresholds = {
                10: 0.5,
                20: 0.8,
            }

        generator = torch.Generator(device=self.device).manual_seed(seed)
        if apply_attend_and_excite:
            token_indices = self._parse_token_indices(indices_to_alter_str)
            out = self.ax_pipe(
                prompt=prompt,
                token_indices=token_indices,
                guidance_scale=guidance_scale,
                generator=generator,
                num_inference_steps=num_steps,
                max_iter_to_alter=max_iter_to_alter,
                thresholds=thresholds,
                scale_factor=scale_factor,
            )
        else:
            out = self.sd_pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                generator=generator,
                num_inference_steps=num_steps,
            )
        return out.images[0]