Spaces:
Runtime error
Runtime error
File size: 4,838 Bytes
4fb3c5e cb229bd 4fb3c5e cb229bd 4fb3c5e 6f3a230 4fb3c5e cb229bd 4fb3c5e 6f3a230 4fb3c5e 6f3a230 4fb3c5e cb229bd 4fb3c5e cb229bd 4fb3c5e cb229bd 4fb3c5e cb229bd 4fb3c5e cb229bd b6c80e7 cb229bd 0dc864d cb229bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
from __future__ import annotations
import logging
import os
import random
import sys
import numpy as np
import PIL.Image
import torch
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
DiffusionPipeline, PNDMPipeline, PNDMScheduler)
# Hub access token for the model repos. Optional: when the variable is unset
# this is None, which is presumably equivalent to not passing a token to
# `from_pretrained(use_auth_token=...)` (enough for public repos). Reading it
# with `[]` would instead crash the whole module at import time.
HF_TOKEN = os.environ.get('HF_TOKEN')

# Module logger: INFO and above, written to stdout with timestamps, and not
# propagated to the root logger (avoids double printing via root handlers).
formatter = logging.Formatter(
    '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False
# getLogger(__name__) returns the same logger object on re-import (e.g.
# importlib.reload), so only attach a handler once to avoid duplicate lines.
if not logger.handlers:
    logger.addHandler(stream_handler)
class Model:
    """Anime-face image generator backed by diffusers pipelines.

    Holds exactly one loaded pipeline at a time, described by
    ``model_name`` and ``scheduler_type``; :meth:`set_pipeline` swaps it
    when either of those changes.
    """

    # Suffixes of the ``hysts/diffusers-anime-faces-*`` Hub repos that can be
    # loaded. The first entry is the default model.
    MODEL_NAMES = [
        'ddpm-128-exp000',
    ]

    def __init__(self, device: str | torch.device):
        """Pre-fetch all models and load the default one with DDIM.

        Args:
            device: Device handed to the pipeline at generation time.
        """
        self.device = torch.device(device)
        self._download_all_models()
        self.model_name = self.MODEL_NAMES[0]
        self.scheduler_type = 'DDIM'
        self.pipeline = self._load_pipeline(self.model_name,
                                            self.scheduler_type)
        # Separate RNG used only to draw seeds in run_simple().
        self.rng = random.Random()

    def _load_pipeline(self, model_name: str,
                       scheduler_type: str) -> DiffusionPipeline:
        """Load ``model_name`` from the Hub, configured for ``scheduler_type``.

        Args:
            model_name: Suffix of the ``hysts/diffusers-anime-faces-*`` repo.
            scheduler_type: One of ``'DDPM'``, ``'DDIM'`` or ``'PNDM'``.

        Returns:
            The loaded pipeline (not moved to any device here).

        Raises:
            ValueError: If ``scheduler_type`` is not supported.
        """
        repo_id = f'hysts/diffusers-anime-faces-{model_name}'
        if scheduler_type == 'DDPM':
            pipeline = DDPMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
        elif scheduler_type == 'DDIM':
            pipeline = DDIMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            # Rebuild the stored scheduler as a DDIMScheduler from the same
            # config; extract_init_dict keeps the arguments DDIMScheduler
            # accepts and returns the leftovers as the second element.
            config, _ = DDIMScheduler.extract_init_dict(
                dict(pipeline.scheduler.config))
            pipeline.scheduler = DDIMScheduler(**config)
        elif scheduler_type == 'PNDM':
            pipeline = PNDMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            # Same scheduler swap as above, for PNDM.
            config, _ = PNDMScheduler.extract_init_dict(
                dict(pipeline.scheduler.config))
            pipeline.scheduler = PNDMScheduler(**config)
        else:
            raise ValueError(
                f'Unsupported scheduler type: {scheduler_type!r} '
                "(expected 'DDPM', 'DDIM' or 'PNDM')")
        return pipeline

    def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
        """Make ``(model_name, scheduler_type)`` the active pipeline.

        Does nothing if that combination is already loaded. The new pipeline
        is loaded *before* any state is updated, so if loading raises, the
        object still consistently describes the previously loaded pipeline.

        Raises:
            ValueError: If ``scheduler_type`` is not supported.
        """
        logger.info('--- set_pipeline ---')
        logger.info(f'{model_name=}, {scheduler_type=}')
        if model_name == self.model_name and scheduler_type == self.scheduler_type:
            logger.info('Skipping')
            logger.info('--- done ---')
            return
        pipeline = self._load_pipeline(model_name, scheduler_type)
        # Commit only after a successful load.
        self.pipeline = pipeline
        self.model_name = model_name
        self.scheduler_type = scheduler_type
        logger.info('--- done ---')

    def _download_all_models(self) -> None:
        """Instantiate (and discard) a DDPM pipeline for every model.

        Presumably done to download and cache all weights up front, so later
        switches via set_pipeline do not hit the network.
        """
        for name in self.MODEL_NAMES:
            self._load_pipeline(name, 'DDPM')

    def generate(self,
                 seed: int,
                 num_steps: int,
                 num_images: int = 1) -> list[PIL.Image.Image]:
        """Sample ``num_images`` images from the active pipeline.

        Args:
            seed: Seed applied to torch's global RNG via ``torch.manual_seed``.
            num_steps: Number of inference steps; ignored for DDPM, which is
                called without ``num_inference_steps``.
            num_images: Batch size.

        Returns:
            The pipeline output's ``'sample'`` entry.

        Raises:
            ValueError: If ``self.scheduler_type`` is not supported.
        """
        logger.info('--- generate ---')
        logger.info(f'{seed=}, {num_steps=}')
        # Seeds the *global* torch RNG: reproducible per seed, but shared with
        # any other torch code running in this process.
        torch.manual_seed(seed)
        # NOTE(review): `torch_device=` and the `['sample']` output key match
        # early diffusers releases; later releases use `pipeline.to(device)`
        # and `.images` instead — confirm against the pinned diffusers version.
        if self.scheduler_type == 'DDPM':
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device)['sample']
        elif self.scheduler_type in ['DDIM', 'PNDM']:
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device,
                                num_inference_steps=num_steps)['sample']
        else:
            raise ValueError(
                f'Unsupported scheduler type: {self.scheduler_type!r}')
        logger.info('--- done ---')
        return res

    def run(
        self,
        model_name: str,
        scheduler_type: str,
        num_steps: int,
        seed: int,
    ) -> PIL.Image.Image:
        """Switch to the requested pipeline and return one generated image.

        For ``'PNDM'`` the step count is clamped to the range [4, 100].
        """
        self.set_pipeline(model_name, scheduler_type)
        if scheduler_type == 'PNDM':
            num_steps = max(4, min(num_steps, 100))
        return self.generate(seed, num_steps)[0]

    @staticmethod
    def _tile(arrays: list[np.ndarray], ncols: int) -> np.ndarray:
        """Tile equally shaped ``H x W[ x C]`` arrays into a row-major grid.

        Missing cells in the last row are filled with 255 (white for uint8
        image data) using the dtype and shape of the first array.

        Raises:
            ValueError: If ``arrays`` is empty or ``ncols`` is not positive.
        """
        if not arrays:
            raise ValueError('Cannot build a grid from an empty image list')
        if ncols < 1:
            raise ValueError(f'ncols must be a positive integer, got {ncols}')
        nrows = (len(arrays) + ncols - 1) // ncols
        first = arrays[0]
        h, w = first.shape[:2]
        tail = first.shape[2:]  # channel axis (if any) is carried through
        padding = nrows * ncols - len(arrays)
        cells = list(arrays) + [np.full_like(first, 255)] * padding
        stacked = np.stack(cells)
        # (nrows, ncols, h, w, ...) -> (nrows, h, ncols, w, ...) -> image.
        return (stacked.reshape(nrows, ncols, h, w, *tail)
                .swapaxes(1, 2)
                .reshape(nrows * h, ncols * w, *tail))

    @staticmethod
    def to_grid(images: list[PIL.Image.Image],
                ncols: int = 2) -> PIL.Image.Image:
        """Combine ``images`` into one grid image with ``ncols`` columns.

        Images are placed row-major; unused cells in the last row are white.
        All images must share the same size and mode.

        Raises:
            ValueError: If ``images`` is empty or ``ncols`` is not positive.
        """
        arrays = [np.asarray(image) for image in images]
        return PIL.Image.fromarray(Model._tile(arrays, ncols))

    def run_simple(self) -> PIL.Image.Image:
        """Generate a 2x2 grid with the default model, DDIM and 10 steps."""
        self.set_pipeline(self.MODEL_NAMES[0], 'DDIM')
        seed = self.rng.randint(0, 1000000)
        images = self.generate(seed, num_steps=10, num_images=4)
        return self.to_grid(images, 2)
|