File size: 3,864 Bytes
4fb3c5e
 
 
 
 
 
 
 
6f3a230
 
4fb3c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f3a230
4fb3c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f3a230
4fb3c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import annotations

import logging
import os
import sys

import PIL.Image
import torch
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
                       DiffusionPipeline, PNDMPipeline, PNDMScheduler)

# Hugging Face access token used to download the model repos.
# Read with [] rather than .get() so a missing token fails fast at import
# time with a KeyError instead of a confusing auth error later.
HF_TOKEN = os.environ['HF_TOKEN']

# Module-level logger emitting timestamped INFO records to stdout.
formatter = logging.Formatter(
    '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Don't forward records to the root logger: avoids duplicate output when the
# hosting app configures its own root handlers.
logger.propagate = False
logger.addHandler(stream_handler)


class Model:
    """Wrapper around `diffusers` anime-face pipelines.

    Downloads every known model at construction time (warming the local
    Hugging Face cache), then keeps a single active (model, scheduler)
    pipeline that is reloaded only when the requested combination changes.
    """

    # Repos are published as hysts/diffusers-anime-faces-<name>.
    MODEL_NAMES = [
        'ddpm-128-exp000',
    ]

    # Maps scheduler type -> (pipeline class, replacement scheduler class).
    # A None scheduler class means the pretrained scheduler is kept as-is.
    _PIPELINE_CLASSES = {
        'DDPM': (DDPMPipeline, None),
        'DDIM': (DDIMPipeline, DDIMScheduler),
        'PNDM': (PNDMPipeline, PNDMScheduler),
    }

    def __init__(self, device: str | torch.device):
        """Initialize on `device` with the first model and a DDIM scheduler."""
        self.device = torch.device(device)
        self._download_all_models()

        self.model_name = self.MODEL_NAMES[0]
        self.scheduler_type = 'DDIM'
        self.pipeline = self._load_pipeline(self.model_name,
                                            self.scheduler_type)

    def _load_pipeline(self, model_name: str,
                       scheduler_type: str) -> DiffusionPipeline:
        """Load the pipeline for `model_name` with the requested scheduler.

        Raises:
            ValueError: if `scheduler_type` is not 'DDPM', 'DDIM' or 'PNDM'.
        """
        try:
            pipeline_cls, scheduler_cls = self._PIPELINE_CLASSES[
                scheduler_type]
        except KeyError:
            raise ValueError(
                f'Unknown scheduler type: {scheduler_type!r}') from None
        repo_id = f'hysts/diffusers-anime-faces-{model_name}'
        pipeline = pipeline_cls.from_pretrained(repo_id,
                                                use_auth_token=HF_TOKEN)
        if scheduler_cls is not None:
            # Rebuild the scheduler from the pretrained config so sampling
            # follows the requested algorithm while keeping the trained
            # noise-schedule parameters.
            config, _ = scheduler_cls.extract_init_dict(
                dict(pipeline.scheduler.config))
            pipeline.scheduler = scheduler_cls(**config)
        return pipeline

    def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
        """Switch the active pipeline; no-op if nothing changed."""
        logger.info('--- set_pipeline ---')
        logger.info(f'{model_name=}, {scheduler_type=}')

        if model_name == self.model_name and scheduler_type == self.scheduler_type:
            logger.info('Skipping')
            logger.info('--- done ---')
            return
        self.model_name = model_name
        self.scheduler_type = scheduler_type
        self.pipeline = self._load_pipeline(model_name, scheduler_type)

        logger.info('--- done ---')

    def _download_all_models(self) -> None:
        """Fetch every model once so later switches hit the local cache."""
        for name in self.MODEL_NAMES:
            self._load_pipeline(name, 'DDPM')

    def generate(self, seed: int, num_steps: int) -> PIL.Image.Image:
        """Generate one image with the active pipeline.

        Args:
            seed: value for `torch.manual_seed`, making output reproducible.
            num_steps: inference steps for DDIM/PNDM; ignored for DDPM,
                which always runs its full diffusion schedule.

        Raises:
            ValueError: if `self.scheduler_type` is unrecognized.
        """
        logger.info('--- generate ---')
        logger.info(f'{seed=}, {num_steps=}')

        torch.manual_seed(seed)
        if self.scheduler_type == 'DDPM':
            res = self.pipeline(batch_size=1,
                                torch_device=self.device)['sample'][0]
        elif self.scheduler_type in ['DDIM', 'PNDM']:
            res = self.pipeline(batch_size=1,
                                torch_device=self.device,
                                num_inference_steps=num_steps)['sample'][0]
        else:
            raise ValueError(
                f'Unknown scheduler type: {self.scheduler_type!r}')

        logger.info('--- done ---')
        return res

    def run(
        self,
        model_name: str,
        scheduler_type: str,
        num_steps: int,
        seed: int,
    ) -> PIL.Image.Image:
        """Select the (model, scheduler) pipeline and generate one image."""
        self.set_pipeline(model_name, scheduler_type)
        if scheduler_type == 'PNDM':
            # PNDM needs a handful of warm-up steps, so clamp to [4, 100].
            num_steps = max(4, min(num_steps, 100))
        return self.generate(seed, num_steps)