File size: 4,103 Bytes
93428fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488e83c
 
 
 
93428fb
488e83c
 
 
 
 
 
 
93428fb
488e83c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb9e832
 
 
 
 
3c21e52
eb9e832
 
 
 
 
 
 
 
0170156
 
 
 
 
 
 
3c21e52
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gc
import os
import random
import numpy as np
import json
import torch
from PIL import Image, PngImagePlugin
from datetime import datetime
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
)

# Upper bound (inclusive) for randomly drawn seeds: the largest value
# representable as a signed 32-bit integer (2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for a run.

    Args:
        seed: Seed supplied by the caller.
        randomize_seed: When truthy, ``seed`` is ignored and a new one is drawn.

    Returns:
        ``seed`` unchanged, or a random integer in ``[0, MAX_SEED]`` (both
        ends inclusive) when ``randomize_seed`` is truthy.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed

def seed_everything(seed: int) -> torch.Generator:
    """Seed the global NumPy and torch RNGs and return a seeded generator.

    Args:
        seed: Value applied to every random number generator touched here.

    Returns:
        A newly constructed ``torch.Generator`` seeded with ``seed``.
    """
    # Global state: NumPy, then torch (default generator and all CUDA devices).
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # ``Generator.manual_seed`` returns the generator itself.
    return torch.Generator().manual_seed(seed)

def parse_aspect_ratio(aspect_ratio: str) -> Optional[Tuple[int, int]]:
    """Turn a ``"<width> x <height>"`` label into an integer pair.

    Args:
        aspect_ratio: Either the literal ``"Custom"`` or a string whose two
            integer fields are separated by ``" x "``.

    Returns:
        ``None`` for ``"Custom"``; otherwise ``(width, height)`` as ints.

    Raises:
        ValueError: If the string does not split into exactly two parts or a
            part is not a valid integer.
    """
    if aspect_ratio != "Custom":
        width_text, height_text = aspect_ratio.split(" x ")
        return int(width_text), int(height_text)
    return None

def aspect_ratio_handler(aspect_ratio: str, custom_width: int, custom_height: int) -> Tuple[int, int]:
    """Resolve the output size from a preset label or the custom values.

    Args:
        aspect_ratio: ``"Custom"`` or a label accepted by ``parse_aspect_ratio``.
        custom_width: Width returned when ``aspect_ratio`` is ``"Custom"``.
        custom_height: Height returned when ``aspect_ratio`` is ``"Custom"``.

    Returns:
        ``(width, height)``.
    """
    # Guard clause: the custom values are passed through untouched.
    if aspect_ratio == "Custom":
        return custom_width, custom_height

    preset_width, preset_height = parse_aspect_ratio(aspect_ratio)
    return preset_width, preset_height

def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
    """Build the scheduler registered under a sampler display name.

    Args:
        scheduler_config: Configuration passed to the scheduler class's
            ``from_config``.
        name: Sampler display name, e.g. ``"Euler a"`` or ``"DPM++ 2M Karras"``.

    Returns:
        Whatever the matching scheduler class's ``from_config`` returns, or
        ``None`` when ``name`` is not one of the known samplers.
    """
    # Each entry is a zero-argument factory so that only the requested
    # scheduler is ever constructed.
    builders = {
        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        ),
        "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(scheduler_config),
        "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
    }

    build = builders.get(name)
    if build is None:
        return None
    return build()

def free_memory() -> None:
    """Release unreferenced Python objects and cached CUDA memory.

    The garbage collector runs first so that objects kept alive only by
    reference cycles are destroyed before the CUDA cache is emptied;
    otherwise any GPU blocks they free would remain in PyTorch's cache
    until the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()

def common_upscale(samples: torch.Tensor, width: int, height: int, upscale_method: str) -> torch.Tensor:
    """Resize ``samples`` to an explicit size with ``interpolate``.

    Args:
        samples: Tensor to resize (presumably laid out as N x C x H x W,
            since a two-element size is passed; confirm against callers).
        width: Target width.
        height: Target height.
        upscale_method: Passed as ``mode`` to
            ``torch.nn.functional.interpolate``.

    Returns:
        The resized tensor.
    """
    # ``interpolate`` takes the size as (height, width), the reverse of this
    # function's argument order.
    target_size = (height, width)
    resized = torch.nn.functional.interpolate(samples, size=target_size, mode=upscale_method)
    return resized

def upscale(samples: torch.Tensor, upscale_method: str, scale_by: float) -> torch.Tensor:
    """Resize ``samples`` by a scale factor.

    Dimension 2 of ``samples`` is treated as height and dimension 3 as width.
    Each is multiplied by ``scale_by`` and rounded with the built-in
    ``round`` before delegating to ``common_upscale``.

    Args:
        samples: Tensor with at least four dimensions.
        upscale_method: Interpolation mode forwarded to ``common_upscale``.
        scale_by: Multiplier applied to both height and width.

    Returns:
        The resized tensor.
    """
    source_height, source_width = samples.shape[2], samples.shape[3]
    target_width = round(source_width * scale_by)
    target_height = round(source_height * scale_by)
    return common_upscale(samples, target_width, target_height, upscale_method)

def preprocess_image_dimensions(width, height):
    """Round both dimensions down to the nearest multiple of 8.

    Args:
        width: Requested width.
        height: Requested height.

    Returns:
        ``(width, height)`` with each value reduced by its remainder modulo 8
        (values already divisible by 8 are returned unchanged).
    """
    # Subtracting the remainder is a no-op when it is zero, so no explicit
    # divisibility check is needed.
    return width - width % 8, height - height % 8

def save_image(image, metadata, output_dir):
    """Save ``image`` as a PNG with ``metadata`` embedded as a text chunk.

    The file is named ``image_<YYYYmmdd_HHMMSS>.png``. Because the timestamp
    only has one-second resolution, a numeric suffix (``_1``, ``_2``, ...) is
    appended when that name is already taken, so images saved within the
    same second do not overwrite each other.

    Args:
        image: Object exposing ``save(path, format, pnginfo=...)``
            (presumably a ``PIL.Image.Image``; confirm against callers).
        metadata: JSON-serialisable value stored under the ``"metadata"``
            PNG text key.
        output_dir: Directory to write into; created if it does not exist.

    Returns:
        The path of the file that was written.
    """
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(output_dir, exist_ok=True)

    stem = f"image_{current_time}"
    filepath = os.path.join(output_dir, f"{stem}.png")
    # Avoid clobbering an image produced earlier in the same second.
    suffix = 1
    while os.path.exists(filepath):
        filepath = os.path.join(output_dir, f"{stem}_{suffix}.png")
        suffix += 1

    metadata_str = json.dumps(metadata)
    info = PngImagePlugin.PngInfo()
    info.add_text("metadata", metadata_str)
    image.save(filepath, "PNG", pnginfo=info)
    return filepath

def is_google_colab():
    """Report whether the ``google.colab`` package can be imported.

    Returns:
        ``True`` if ``import google.colab`` succeeds, ``False`` otherwise.
    """
    try:
        import google.colab  # noqa: F401  (imported only as an availability probe)
    except Exception:
        # Best-effort probe: any import-time failure means "not Colab".
        # Unlike a bare ``except``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return False
    return True

def validate_json_parameters(json_str):
    """Parse a JSON parameter string and check that required keys exist.

    Args:
        json_str: JSON text that must decode to an object containing the
            keys ``prompt``, ``negative_prompt``, ``resolution``,
            ``guidance_scale``, ``num_inference_steps``, ``seed`` and
            ``sampler``.

    Returns:
        The decoded parameters as a ``dict``.

    Raises:
        ValueError: If the text is not valid JSON, cannot be parsed at all,
            does not decode to an object, or lacks a required key. The
            message names the first missing key.
    """
    required_keys = ['prompt', 'negative_prompt', 'resolution', 'guidance_scale', 'num_inference_steps', 'seed', 'sampler']

    # Only the decode step is guarded, so validation errors raised below are
    # not re-wrapped with a misleading "Error parsing JSON" prefix.
    try:
        params = json.loads(json_str)
    except json.JSONDecodeError as err:
        raise ValueError("Invalid JSON format") from err
    except Exception as err:
        # e.g. TypeError when json_str is not str/bytes; callers still get
        # a ValueError.
        raise ValueError(f"Error parsing JSON: {str(err)}") from err

    # Membership tests on a list, string or number would be meaningless
    # (substring matches) or raise TypeError.
    if not isinstance(params, dict):
        raise ValueError("JSON parameters must be an object")

    for key in required_keys:
        if key not in params:
            raise ValueError(f"Missing required key: {key}")
    return params

import base64
from io import BytesIO

def image_to_base64(image):
    """Encode ``image`` as a base64 string of its PNG bytes.

    Args:
        image: Object exposing ``save(file_obj, format=...)``.

    Returns:
        The base64 text (decoded as UTF-8) of what ``image.save`` wrote to
        an in-memory buffer with ``format="PNG"``.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")