Spaces:
Build error
Build error
File size: 2,405 Bytes
9042918 1fa5d2c 9042918 114e9fc bc96fce 71da51f 9042918 8880ecb 9042918 8880ecb 9042918 8880ecb 9042918 114e9fc 71da51f 9042918 8880ecb 9042918 8880ecb 1fa5d2c 7b6145e 9042918 8880ecb 9042918 1fa5d2c 9042918 7b6145e 8880ecb 9042918 1fa5d2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from __future__ import annotations
import gc
import pathlib
import sys
import gradio as gr
import PIL.Image
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
# Make the bundled custom-diffusion checkout importable (the deferred
# ``from src import diffuser_training`` in InferencePipeline.load_pipe is
# expected to resolve from it). Anchor on this module's directory rather than
# the process's working directory so imports work wherever the app starts.
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent / 'custom-diffusion'))
class InferencePipeline:
    """Build, cache and run a Stable Diffusion pipeline with custom-diffusion weights.

    The pipeline is rebuilt only when the requested base model or weight
    file differs from the one currently loaded.
    """

    def __init__(self):
        self.pipe = None
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # Cache key of the currently loaded pipeline. Both stay None until a
        # load has completed successfully.
        self.weight_path = None
        self.model_id = None

    def clear(self) -> None:
        """Drop the cached pipeline and release the memory it held."""
        self.weight_path = None
        self.model_id = None
        del self.pipe
        self.pipe = None
        # Collect first so objects kept alive only by reference cycles are
        # actually freed before the CUDA caching allocator releases blocks.
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def get_weight_path(name: str) -> pathlib.Path:
        """Return the path of weight file *name* inside this module's directory."""
        curr_dir = pathlib.Path(__file__).parent
        return curr_dir / name

    def load_pipe(self, model_id: str, filename: str) -> None:
        """Ensure ``self.pipe`` is *model_id* patched with weight file *filename*.

        Returns immediately when that exact (model, weights) pair is already
        loaded. Otherwise the old pipeline is released, and a new one is built
        and patched via ``diffuser_training.load_model``.

        Args:
            model_id: Base model identifier passed to
                ``StableDiffusionPipeline.from_pretrained``.
            filename: Weight file name, resolved by ``get_weight_path``.

        Raises:
            Whatever ``from_pretrained`` or ``load_model`` raise. In that case
            no pipeline is cached, and the next call retries the load.
        """
        weight_path = self.get_weight_path(filename)
        if weight_path == self.weight_path and model_id == self.model_id:
            return

        # Release the previous pipeline before building the new one so two
        # full models are never resident at once. This also invalidates the
        # cache key, so a failed load below cannot leave a stale pipeline
        # registered under the new name.
        self.clear()

        if self.device.type == 'cpu':
            pipe = StableDiffusionPipeline.from_pretrained(model_id)
        else:
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id, torch_dtype=torch.float16)
            pipe = pipe.to(self.device)

        # Deferred import: presumably resolved from the custom-diffusion
        # checkout that is put on sys.path at module import time.
        from src import diffuser_training
        # '<new1>' is presumably the modifier token the weights were trained
        # with. TODO confirm against the training configuration.
        diffuser_training.load_model(pipe.text_encoder, pipe.tokenizer,
                                     pipe.unet, weight_path, '<new1>')

        # Commit the cache key only once the pipeline is fully ready.
        self.pipe = pipe
        self.weight_path = weight_path
        self.model_id = model_id

    def run(
        self,
        base_model: str,
        weight_name: str,
        prompt: str,
        seed: int,
        n_steps: int,
        guidance_scale: float,
        eta: float,
        batch_size: int,
        resolution: int,
    ) -> PIL.Image.Image:
        """Generate ``batch_size`` images for *prompt* and tile them left-to-right.

        Args:
            base_model: Base model identifier forwarded to ``load_pipe``.
            weight_name: Weight file name forwarded to ``load_pipe``.
            prompt: Text prompt, repeated once per image in the batch.
            seed: Seed for the sampling ``torch.Generator``.
            n_steps: Number of inference steps.
            guidance_scale: Classifier-free guidance scale.
            eta: ``eta`` parameter forwarded to the pipeline call.
            batch_size: Number of images to generate.
            resolution: Height and width of each generated image.

        Returns:
            A single image with all generated images concatenated horizontally.

        Raises:
            gr.Error: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')

        self.load_pipe(base_model, weight_name)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        out = self.pipe([prompt] * batch_size,
                        num_inference_steps=n_steps,
                        guidance_scale=guidance_scale,
                        height=resolution, width=resolution,
                        eta=eta,
                        generator=generator)  # type: ignore
        # np.hstack joins the per-image arrays along the width axis, giving
        # one horizontal strip.
        return PIL.Image.fromarray(np.hstack([np.array(x) for x in out.images]))
|