Spaces:
Runtime error
Runtime error
File size: 3,864 Bytes
4fb3c5e 6f3a230 4fb3c5e 6f3a230 4fb3c5e 6f3a230 4fb3c5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
from __future__ import annotations
import logging
import os
import sys
import PIL.Image
import torch
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
DiffusionPipeline, PNDMPipeline, PNDMScheduler)
# Hugging Face access token forwarded as ``use_auth_token`` to every
# ``from_pretrained`` call. Read with ``os.getenv`` so a missing variable no
# longer crashes the module at import time. With ``None``, ``from_pretrained``
# presumably falls back to a cached login or anonymous access. If the model
# repos are private, the token is still required.
HF_TOKEN = os.getenv('HF_TOKEN')

# Module logger: INFO and above go to stdout with a timestamped format.
# propagate is disabled so records are not also emitted by ancestor (root)
# logger handlers.
formatter = logging.Formatter(
    '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False
logger.addHandler(stream_handler)
class Model:
    """Wrapper around the ``hysts/diffusers-anime-faces-*`` diffusion pipelines.

    Keeps one active pipeline, identified by ``model_name`` and
    ``scheduler_type`` ('DDPM', 'DDIM' or 'PNDM'), and reloads it only when
    either of those changes.
    """

    # Suffixes of the ``hysts/diffusers-anime-faces-<name>`` repositories.
    MODEL_NAMES = [
        'ddpm-128-exp000',
    ]

    def __init__(self, device: str | torch.device):
        """Fetch every known model and load the default (DDIM) pipeline.

        Args:
            device: Device passed as ``torch_device`` on every generation call.
        """
        self.device = torch.device(device)
        self._download_all_models()
        self.model_name = self.MODEL_NAMES[0]
        self.scheduler_type = 'DDIM'
        self.pipeline = self._load_pipeline(self.model_name,
                                            self.scheduler_type)

    def _load_pipeline(self, model_name: str,
                       scheduler_type: str) -> DiffusionPipeline:
        """Load ``model_name`` as a pipeline driven by ``scheduler_type``.

        Args:
            model_name: Repository suffix, e.g. an entry of ``MODEL_NAMES``.
            scheduler_type: One of 'DDPM', 'DDIM' or 'PNDM'.

        Returns:
            The loaded pipeline. For DDIM/PNDM, its scheduler is replaced by
            one of the matching class.

        Raises:
            ValueError: If ``scheduler_type`` is not supported. This is
                checked before anything is downloaded.
        """
        repo_id = f'hysts/diffusers-anime-faces-{model_name}'
        if scheduler_type == 'DDPM':
            return DDPMPipeline.from_pretrained(repo_id,
                                                use_auth_token=HF_TOKEN)
        if scheduler_type == 'DDIM':
            pipeline_cls, scheduler_cls = DDIMPipeline, DDIMScheduler
        elif scheduler_type == 'PNDM':
            pipeline_cls, scheduler_cls = PNDMPipeline, PNDMScheduler
        else:
            raise ValueError(f'Unknown scheduler_type: {scheduler_type!r}')
        pipeline = pipeline_cls.from_pretrained(repo_id,
                                                use_auth_token=HF_TOKEN)
        # Rebuild the scheduler as `scheduler_cls` from the stored scheduler
        # config. extract_init_dict returns the init kwargs `scheduler_cls`
        # accepts; the discarded second value presumably holds the leftover
        # config entries (e.g. DDPM-only ones) — confirm against diffusers.
        config, _ = scheduler_cls.extract_init_dict(
            dict(pipeline.scheduler.config))
        pipeline.scheduler = scheduler_cls(**config)
        return pipeline

    def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
        """Switch to ``model_name``/``scheduler_type``, reloading if needed.

        Does nothing when both already match the current pipeline. If loading
        fails, the exception propagates. The current pipeline and its recorded
        ``model_name``/``scheduler_type`` are then left unchanged.
        """
        logger.info('--- set_pipeline ---')
        logger.info(f'{model_name=}, {scheduler_type=}')
        if (model_name == self.model_name
                and scheduler_type == self.scheduler_type):
            logger.info('Skipping')
            logger.info('--- done ---')
            return
        # Load first, then commit. If the names were recorded before a failed
        # load, a later call with the same arguments would hit the "Skipping"
        # path and silently keep using the previous pipeline.
        pipeline = self._load_pipeline(model_name, scheduler_type)
        self.pipeline = pipeline
        self.model_name = model_name
        self.scheduler_type = scheduler_type
        logger.info('--- done ---')

    def _download_all_models(self) -> None:
        """Load every entry of ``MODEL_NAMES`` once as a DDPM pipeline.

        The loaded pipelines are discarded. Presumably this exists to populate
        the local download cache up front so later switches are fast.
        """
        for name in self.MODEL_NAMES:
            self._load_pipeline(name, 'DDPM')

    def generate(self, seed: int, num_steps: int) -> PIL.Image.Image:
        """Sample one image from the current pipeline.

        Args:
            seed: Value passed to ``torch.manual_seed`` (the global RNG).
            num_steps: Forwarded as ``num_inference_steps`` for DDIM/PNDM.
                It is not forwarded for DDPM.

        Returns:
            The first element of the pipeline output's ``'sample'`` entry.

        Raises:
            ValueError: If the current ``scheduler_type`` is not supported.
        """
        logger.info('--- generate ---')
        logger.info(f'{seed=}, {num_steps=}')
        torch.manual_seed(seed)
        # NOTE(review): `torch_device=` and `['sample']` follow an older
        # diffusers call/output convention; newer releases return an object
        # with `.images` instead. Confirm against the pinned diffusers version.
        if self.scheduler_type == 'DDPM':
            res = self.pipeline(batch_size=1,
                                torch_device=self.device)['sample'][0]
        elif self.scheduler_type in ('DDIM', 'PNDM'):
            res = self.pipeline(batch_size=1,
                                torch_device=self.device,
                                num_inference_steps=num_steps)['sample'][0]
        else:
            raise ValueError(
                f'Unknown scheduler_type: {self.scheduler_type!r}')
        logger.info('--- done ---')
        return res

    def run(
        self,
        model_name: str,
        scheduler_type: str,
        num_steps: int,
        seed: int,
    ) -> PIL.Image.Image:
        """Select the requested pipeline, then generate one image with it."""
        self.set_pipeline(model_name, scheduler_type)
        if scheduler_type == 'PNDM':
            # PNDM is only run with 4..100 steps. The lower bound is presumably
            # tied to PNDM's warm-up steps — confirm.
            num_steps = max(4, min(num_steps, 100))
        return self.generate(seed, num_steps)
|