# stable-diffusion-demo/src/inference.py
import gradio as gr
import spaces
import torch
from PIL import Image
from compel import Compel, DiffusersTextualInversionManager
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from diffusers.utils import make_image_grid
from src.const import DIFFUSERS_MODEL_IDS, EXTERNAL_MODEL_MAPPING, DEVICE


def load_pipeline(model_id, use_model_offload, safety_checker):
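    """Build a Diffusers pipeline for `model_id`, attach the bundled SD1.5
    embeddings and LoRAs when the architecture supports them, and place the
    pipeline on the target device (or enable CPU offload)."""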
    # Model hosted in a Diffusers repository on the Hub
    if model_id in DIFFUSERS_MODEL_IDS:
        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
        )
    # Model sourced from CIVITAI, resolved via an external path mapping
    else:
        pipe = DiffusionPipeline.from_pretrained(
            EXTERNAL_MODEL_MAPPING[model_id],
            torch_dtype=torch.float16,
        )
    # Load textual inversion embeddings and LoRA weights. These checkpoints
    # target Stable Diffusion 1.5, so skip them for other architectures
    # (e.g. SD3), whose pipelines do not accept them.
    if isinstance(pipe, StableDiffusionPipeline):
        # Negative embeddings, referenced in prompts via their trigger tokens
        pipe.load_textual_inversion("checkpoints/embeddings/BadNegAnatomyV1 neg.pt", token='BadNegAnatomyV1-neg')
        pipe.load_textual_inversion("checkpoints/embeddings/Deep Negative V1 75T.pt", token='DeepNegative')
        pipe.load_textual_inversion("checkpoints/embeddings/easynegative.safetensors", token='EasyNegative')
        pipe.load_textual_inversion("checkpoints/embeddings/Negative Hand Embedding.pt", token='negative_hand-neg')
        # LoRA adapters ('perfection' is loaded but not activated below)
        pipe.load_lora_weights("checkpoints/lora/detailed style SD1.5.safetensors", adapter_name='detail')
        pipe.load_lora_weights("checkpoints/lora/perfection style SD1.5.safetensors", adapter_name='perfection')
        pipe.load_lora_weights("checkpoints/lora/Hand v3 SD1.5.safetensors", adapter_name='hands')
        pipe.set_adapters(['detail', 'hands'], adapter_weights=[0.5, 0.5])
    # Fallback for low-VRAM environments: keep submodules on the CPU and move
    # each to the GPU only while it runs
    if use_model_offload:
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(DEVICE)

    # Optionally disable the NSFW filter (present on SD1.x pipelines)
    if not safety_checker:
        pipe.safety_checker = None

    return pipe

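
# `spaces.GPU` requests a ZeroGPU worker for up to 120 s when running on
# Hugging Face Spaces (a no-op elsewhere); `torch.inference_mode` disables
# autograd bookkeeping for the whole call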
@spaces.GPU(duration=120)
@torch.inference_mode()
def inference(
    prompt: str,
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers",
    negative_prompt: str = "",
    width: int = 512,
    height: int = 512,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    num_images: int = 4,
    safety_checker: bool = True,
    use_model_offload: bool = False,
    seed: int = 8888,
    progress=gr.Progress(track_tqdm=True),
) -> Image.Image:
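    """Generate `num_images` images for `prompt` and return them as one grid.

    SD1.5-style pipelines are driven through Compel so that weighted and
    arbitrarily long prompts work; other pipelines receive the raw strings.
    """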
    progress(0, 'Loading pipeline...')
    pipe = load_pipeline(model_id, use_model_offload, safety_checker)

    # Fix the random seed so runs are reproducible
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    if isinstance(pipe, StableDiffusionPipeline):
        # Compel converts prompt strings into embeddings, lifting CLIP's
        # 77-token limit and honoring (word:weight) emphasis syntax; the
        # manager resolves the textual-inversion trigger tokens loaded above
        textual_inversion_manager = DiffusersTextualInversionManager(pipe)
        compel_procs = Compel(
            tokenizer=pipe.tokenizer,
            text_encoder=pipe.text_encoder,
            textual_inversion_manager=textual_inversion_manager,
            truncate_long_prompts=False,
        )
        prompt_embed = compel_procs(prompt)
        negative_prompt_embed = compel_procs(negative_prompt)

        # Untruncated prompts may differ in length, so pad both embeddings to
        # a common sequence length before passing them to the pipeline
        prompt_embed, negative_prompt_embed = compel_procs.pad_conditioning_tensors_to_same_length(
            [prompt_embed, negative_prompt_embed]
        )
        progress(0.3, 'Generating images...')
        images = pipe(
            prompt_embeds=prompt_embed,
            negative_prompt_embeds=negative_prompt_embed,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images,
            generator=generator,
        ).images
    else:
        # Pipelines without a Compel path (e.g. SD3) take raw prompt strings
        progress(0.3, 'Generating images...')
        images = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images,
            generator=generator,
        ).images
    progress(0.9, f'Done generating {num_images} images')

    # Arrange the results in a grid: one column when the count is odd,
    # otherwise two rows
    if num_images % 2 == 1:
        image = make_image_grid(images, rows=num_images, cols=1)
    else:
        image = make_image_grid(images, rows=2, cols=num_images // 2)

    return image
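

# Minimal local usage sketch, for illustration only: it assumes
# 'runwayml/stable-diffusion-v1-5' appears in DIFFUSERS_MODEL_IDS and that the
# checkpoints referenced above are present on disk. A dummy progress callback
# is passed because no Gradio event context exists outside the app.
if __name__ == '__main__':
    grid = inference(
        prompt='a portrait photo of an astronaut, (highly detailed:1.2)',
        model_id='runwayml/stable-diffusion-v1-5',
        negative_prompt='EasyNegative, DeepNegative',
        num_images=2,
        num_inference_steps=25,
        progress=lambda *args, **kwargs: None,
    )
    grid.save('astronaut_grid.png')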