# Source: hysts (HF staff), commit 4fb3c5e "Add files" (3.85 kB).
from __future__ import annotations
import logging
import os
import sys
import PIL.Image
import torch
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline, PNDMPipeline,
PNDMScheduler)
# Access token forwarded to every ``from_pretrained`` call; a missing
# environment variable raises KeyError at import time.
HF_TOKEN = os.environ['HF_TOKEN']

# Timestamped records for this module, written to stdout.
formatter = logging.Formatter(
    fmt='[%(asctime)s] %(name)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
stream_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(stream_handler)
# Do not hand records to ancestor loggers: they are emitted only through
# the stdout handler attached above.
logger.propagate = False
logger.setLevel(logging.INFO)
class Model:
    """Unconditional anime-face image generator built on diffusers pipelines.

    Exactly one pipeline is held at a time; it is identified by
    ``model_name`` (a suffix of the Hub repo
    ``hysts/diffusers-anime-faces-<model_name>``) and ``scheduler_type``
    (one of ``SCHEDULER_TYPES``).
    """

    # Repo-name suffixes of the available checkpoints.
    MODEL_NAMES = [
        'ddpm-128-exp000',
    ]

    # Scheduler types understood by ``_load_pipeline`` and ``generate``.
    SCHEDULER_TYPES = ('DDPM', 'DDIM', 'PNDM')

    def __init__(self, device: str | torch.device):
        """Fetch every checkpoint, then load the first one with DDIM.

        Args:
            device: Device the pipeline runs on at generation time.
        """
        self.device = torch.device(device)
        # Loads (and discards) each checkpoint up front -- presumably so that
        # later model switches are served from the local Hub cache.
        self._download_all_models()
        self.model_name = self.MODEL_NAMES[0]
        self.scheduler_type = 'DDIM'
        self.pipeline = self._load_pipeline(self.model_name,
                                            self.scheduler_type)

    def _load_pipeline(
        self, model_name: str, scheduler_type: str
    ) -> DDIMPipeline | DDPMPipeline | PNDMPipeline:
        """Load ``model_name`` from the Hub wrapped for ``scheduler_type``.

        Args:
            model_name: Suffix of the ``hysts/diffusers-anime-faces-*`` repo.
            scheduler_type: One of ``SCHEDULER_TYPES``.

        Returns:
            The pipeline class matching ``scheduler_type``.

        Raises:
            ValueError: If ``scheduler_type`` is not recognised (checked
                before any download happens).
        """
        if scheduler_type not in self.SCHEDULER_TYPES:
            raise ValueError(f'Unknown scheduler_type: {scheduler_type!r}')

        repo_id = f'hysts/diffusers-anime-faces-{model_name}'
        if scheduler_type == 'DDPM':
            return DDPMPipeline.from_pretrained(repo_id,
                                                use_auth_token=HF_TOKEN)

        if scheduler_type == 'DDIM':
            pipeline = DDIMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            # Rebuild the scheduler as a DDIMScheduler from the stored
            # config; only the constructor kwargs returned by
            # extract_init_dict are passed on (the second element is unused).
            config, _ = DDIMScheduler.extract_init_dict(
                dict(pipeline.scheduler.config))
            pipeline.scheduler = DDIMScheduler(**config)
            return pipeline

        # scheduler_type == 'PNDM' (guaranteed by the check above).
        pipeline = PNDMPipeline.from_pretrained(repo_id,
                                                use_auth_token=HF_TOKEN)
        config, _ = PNDMScheduler.extract_init_dict(
            dict(pipeline.scheduler.config))
        pipeline.scheduler = PNDMScheduler(**config)
        return pipeline

    def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
        """Switch to ``model_name``/``scheduler_type``, reusing it if current.

        Raises:
            ValueError: If ``scheduler_type`` is not recognised; the
                previously loaded pipeline stays active.
        """
        logger.info('--- set_pipeline ---')
        logger.info(f'{model_name=}, {scheduler_type=}')
        if model_name == self.model_name and scheduler_type == self.scheduler_type:
            logger.info('Skipping')
            logger.info('--- done ---')
            return
        # Load first, commit afterwards: if loading raises, the recorded
        # name/type still describe the pipeline that is actually loaded, so
        # a retry is not wrongly skipped and generate() keeps dispatching
        # on the right scheduler type.
        pipeline = self._load_pipeline(model_name, scheduler_type)
        self.model_name = model_name
        self.scheduler_type = scheduler_type
        self.pipeline = pipeline
        logger.info('--- done ---')

    def _download_all_models(self):
        """Load every entry of ``MODEL_NAMES`` once, discarding the result."""
        for name in self.MODEL_NAMES:
            self._load_pipeline(name, 'DDPM')

    def generate(self, seed: int, num_steps: int) -> PIL.Image.Image:
        """Sample one image from the current pipeline.

        Args:
            seed: Passed to ``torch.manual_seed`` (seeds torch's global RNG).
            num_steps: Inference steps for DDIM/PNDM; ignored for DDPM,
                which runs the pipeline with its default step count.

        Raises:
            ValueError: If ``self.scheduler_type`` is not recognised.
        """
        logger.info('--- generate ---')
        logger.info(f'{seed=}, {num_steps=}')
        torch.manual_seed(seed)
        if self.scheduler_type == 'DDPM':
            res = self.pipeline(batch_size=1,
                                torch_device=self.device)['sample'][0]
        elif self.scheduler_type in ('DDIM', 'PNDM'):
            res = self.pipeline(batch_size=1,
                                torch_device=self.device,
                                num_inference_steps=num_steps)['sample'][0]
        else:
            raise ValueError(
                f'Unknown scheduler_type: {self.scheduler_type!r}')
        logger.info('--- done ---')
        return res

    def run(
        self,
        model_name: str,
        scheduler_type: str,
        num_steps: int,
        seed: int,
    ) -> PIL.Image.Image:
        """Select the pipeline, then generate one image.

        For PNDM, ``num_steps`` is clamped to the range [4, 100].
        """
        self.set_pipeline(model_name, scheduler_type)
        if scheduler_type == 'PNDM':
            num_steps = max(4, min(num_steps, 100))
        return self.generate(seed, num_steps)