Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,429 Bytes
126e15c 2000056 8321e61 2000056 8062689 2000056 3b85c22 2000056 08e1345 2000056 8062689 2000056 8321e61 b97a53d 8321e61 2000056 8321e61 2000056 8321e61 2000056 8321e61 2000056 8321e61 546a277 2000056 8321e61 2000056 546a277 2000056 8321e61 2000056 8321e61 2000056 8321e61 2000056 8321e61 2000056 8321e61 2000056 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import spaces
import argparse
import os
import time
from os import path
from PIL import ImageOps
# Point the Hugging Face cache environment variables at a "models" directory
# located next to this file.
# NOTE(review): these assignments sit before the gradio/torch/diffusers imports
# that follow, presumably so those libraries see the cache location when they
# are imported — confirm before reordering the imports.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
import gradio as gr
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from scheduling_tcd import TCDScheduler
# Allow TF32 for CUDA matrix multiplications (PyTorch backend flag).
torch.backends.cuda.matmul.allow_tf32 = True
# JavaScript snippet handed to gr.Blocks(js=...) further down. When the page
# URL's `__theme` query parameter is not 'dark', it sets the parameter to
# 'dark' and navigates to the updated URL.
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
class timer:
    """Context manager that prints how long the enclosed block took.

    On entry it prints ``"<name> starts"``; on exit it prints
    ``"<name> took <seconds>s"`` with the duration rounded to two decimals.
    Exceptions raised inside the block are not suppressed.

    Attributes:
        method: Label used in the printed messages.
        start: ``time.perf_counter()`` reading taken on entry (``None`` before).
        elapsed: Duration of the last completed block in seconds
            (``None`` until the block has exited).
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name
        self.start = None
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by wall-clock adjustments.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        # Return the instance so `with timer(...) as t:` gives access to
        # `t.elapsed` afterwards.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.method} took {round(self.elapsed, 2)}s")
        # Implicitly returns None: any exception from the block propagates.
# Create the local cache directory. `exist_ok=True` already makes this a no-op
# when the directory exists, so no separate existence check is needed.
os.makedirs(cache_path, exist_ok=True)

# Scribble ControlNet weights, loaded in half precision from safetensors files.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-scribble",
    torch_dtype=torch.float16,
    use_safetensors=True,
)

# Stable Diffusion 1.5 pipeline (fp16 variant) wired to the ControlNet above.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
)

# Hyper-SD 1-step LoRA, registered under the adapter name "default".
pipe.load_lora_weights(
    "ByteDance/Hyper-SD",
    weight_name="Hyper-SD15-1step-lora.safetensors",
    adapter_name="default",
)
pipe.to("cuda")

# Replace the pipeline's scheduler with a TCD scheduler built from the existing
# scheduler config, using "trailing" timestep spacing.
pipe.scheduler = TCDScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)
with gr.Blocks(js=js_func) as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                # Single-channel ("L") sketch canvas with a fixed white brush
                # and no upload sources. The composite is inverted before it
                # is handed to the pipeline (see process_image).
                scribble = gr.ImageEditor(
                    type="pil",
                    image_mode="L",
                    crop_size=(512, 512),
                    sources=(),
                    brush=gr.Brush(color_mode="fixed", colors=["#FFFFFF"]),
                    canvas_size=(512, 512),
                )
                num_images = gr.Slider(label="Number of Images", minimum=1, maximum=8, step=1, value=4, interactive=True)
                steps = gr.Slider(label="Inference Steps", minimum=1, maximum=8, step=1, value=1, interactive=True)
                prompt = gr.Text(label="Prompt", value="a photo of a cat", interactive=True)
                eta = gr.Number(label="Eta (Corresponds to parameter eta (η) in the DDIM paper, i.e. 0.0 eqauls DDIM, 1.0 equals LCM)", value=1., interactive=True)
                controlnet_scale = gr.Number(label="ControlNet Conditioning Scale", value=1.0, interactive=True)
                seed = gr.Number(label="Seed", value=3413, interactive=True)
                btn = gr.Button(value="run")
            with gr.Column():
                output = gr.Gallery(height=768, format="png")

    @spaces.GPU
    def process_image(steps, prompt, controlnet_scale, eta, seed, scribble, num_images):
        """Generate images from the current scribble and settings.

        Args:
            steps: Number of inference steps (slider value).
            prompt: Text prompt, repeated once per requested image.
            controlnet_scale: ControlNet conditioning scale.
            eta: Eta value forwarded to the pipeline call.
            seed: Seed for the torch random generator.
            scribble: Value of the image editor; its ``"composite"`` entry is
                used as the sketch.
            num_images: Number of images to generate in one batch.

        Returns:
            The ``.images`` result of the pipeline call, or ``None`` when there
            is no sketch yet or a numeric field is empty.
        """
        # The editor value can be non-empty while its composite is still
        # missing, so check the composite itself rather than the container.
        composite = scribble.get("composite") if scribble else None
        numeric_inputs = (steps, controlnet_scale, eta, seed, num_images)
        # These callbacks fire on every change, including while a number field
        # is cleared mid-edit; skip generation instead of raising on None.
        if composite is None or any(value is None for value in numeric_inputs):
            return None

        # Slider values are used as a list multiplier and a step count, both
        # of which require true integers.
        batch_size = int(num_images)
        control_image = ImageOps.invert(composite)
        generator = torch.Generator().manual_seed(int(seed))

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16), timer("inference"):
            return pipe(
                prompt=[prompt] * batch_size,
                image=[control_image] * batch_size,
                generator=generator,
                num_inference_steps=int(steps),
                guidance_scale=0.,
                eta=float(eta),
                controlnet_conditioning_scale=float(controlnet_scale),
            ).images

    # Input order must match the parameter order of process_image.
    reactive_controls = [steps, prompt, controlnet_scale, eta, seed, scribble, num_images]

    # Re-run generation whenever any control changes, and also on the button.
    for control in reactive_controls:
        control.change(fn=process_image, inputs=reactive_controls, outputs=[output])

    btn.click(process_image, inputs=reactive_controls, outputs=[output])
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()