from lib.get_model import get_model, device
from lib.sampling import edm_sampler
from lib.embedding import extract_features
from lib.encoders import StabilityVAEEncoder
from lib.cond_gen import get_vae_decoder
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage
import torch
import gradio as gr
import json

torch.set_grad_enabled(False)
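
# Three components, as far as the code shows: `net`, the EDM2 denoiser sampled
# in latent space; `cond_gen`, a small VAE decoder that produces condition
# vectors for the "auto" tab; and `stability_encoder`, the Stability VAE used
# to decode sampled latents back to pixels.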
net = get_model()
net.load_state_dict(load_file("model_weights/1girl-edm-xs-test-1.safetensors"))
cond_gen = get_vae_decoder().to(device)
cond_gen.load_state_dict(load_file("model_weights/condgen_vae_decoder.safetensors"))
stability_encoder = StabilityVAEEncoder()
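
# Classifier-free guidance: torch.lerp(a, b, w) = a + w * (b - a), so the
# guided prediction is uncond + scale * (cond - uncond), extrapolating away
# from the unconditional branch when scale > 1. scale == 1 is a fast path
# that skips the extra unconditional forward pass.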
def guided(net, scale=1):
    def f(x, t, label):
        if scale == 1:
            return net(x, t, label)
        return torch.lerp(net(x, t, net.uncond_emb), net(x, t, label), float(scale))

    return f
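
# All randomness (the initial latent and any noise the sampler injects) comes
# from a single seeded generator, so a given seed reproduces the same image.
# The `randn_like` override is needed because torch.randn_like accepts no
# generator argument. The (1, 4, 88, 64) latent presumably decodes to a
# 704x512 image, assuming the usual 8x spatial scale of the Stability VAE.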
def generate_image(label, guidance_scale, n_steps, seed):
    label = torch.tensor(label)[None].to(device)
    gen = torch.Generator(device).manual_seed(seed)
    x = torch.randn((1, 4, 88, 64), device=device, generator=gen)
    randn_like = lambda *a, **ka: torch.zeros_like(*a, **ka).normal_(generator=gen)
    im = edm_sampler(
        guided(net, guidance_scale), x, label, num_steps=n_steps, randn_like=randn_like
    )
    im = stability_encoder.decode(im)
    return ToPILImage()(im[0])
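
# UI: a header row, then a controls column (generate button, sliders, and
# three tabs selecting how the condition vector is produced) next to an
# outputs column. `selected` is a one-element list so the tab-select callback
# can mutate the active tab index in place for `main` to read later.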
with gr.Blocks() as demo:
    selected = [0]
    with gr.Row():
        gr.Markdown(
            """# 1girl-EDM2-XS-test-1 Demo
Demo of a 125M-parameter model trained in 1 GPU-day for generating `1girl solo` images.
For those who don't want to wait 10 minutes for an image, a [Colab demo](https://colab.research.google.com/drive/1aBopomQ6wLJtQUTZ1CZn2i-zGpyLGWKM?usp=sharing) is also available.
[More info](https://huggingface.co/gustproof/1girl-EDM2-XS-test-1)
"""
        )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                btn = gr.Button("Generate", variant="primary")
                guidance = gr.Slider(1, 15, 5, step=0.1, label="Guidance scale")
                n_steps = gr.Slider(2, 35, 24, step=1, label="Inference steps")
                seed = gr.Slider(
                    -1, 2147483647, -1, step=1, label="Random seed (-1: randomize)"
                )
with gr.Tab("Condition: auto") as auto_tab: | |
gr.Markdown("Conditioning is generated with an external model") | |
with gr.Tab("Condition: from image") as img_tab: | |
gr.Markdown( | |
"Conditioning is extracted from the image with a [tagger](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3). " | |
"This is works like [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and not like [img2img](https://github.com/CompVis/stable-diffusion/blob/main/scripts/img2img.py)" | |
) | |
ref_im = gr.Image(label="Reference image", type="pil") | |
with gr.Tab("Condition: precomputed") as txt_tab: | |
gr.Markdown("Use a precomputed 1024D vector a the condition") | |
ref_txt = gr.TextArea( | |
label="Precomputed conditioning", | |
placeholder="Copy & Paste from the output", | |
) | |
        with gr.Column():
            out_im = gr.Image(label="Generated Image", show_download_button=True)
            out_seed = gr.Textbox(label="Seed", show_copy_button=True)
            out_emb = gr.TextArea(label="Condition vector", show_copy_button=True)
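
    # get_label maps the active tab index to a condition source:
    #   0 ("auto"): decode a random 512-D normal sample with the condition VAE,
    #   1 ("from image"): extract tagger features from the reference image,
    #   2 ("precomputed"): parse the JSON vector pasted by the user.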
    def get_label(tab_index, cond_img=None, cond_txt=None):
        if tab_index == 0:
            return cond_gen(torch.randn((1, 512), device=device))[0].detach().cpu()
        if tab_index == 1:
            return extract_features(cond_img, device)
        return torch.tensor(json.loads(cond_txt))

    def on_select(e: gr.SelectData):
        selected[0] = e.index

    for t in [auto_tab, img_tab, txt_tab]:
        t.select(on_select)
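
    # main: resolve the seed (randomize when -1), build the condition for the
    # active tab, sample an image, and return the condition vector rounded to
    # 3 decimals as JSON so it can be pasted back into the "precomputed" tab.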
    def main(guidance, n_steps, seed, cond_img=None, cond_txt=None):
        if seed < 0:
            seed = torch.randint(0, 2147483647, ()).item()
        label = get_label(selected[0], cond_img, cond_txt)
        im = generate_image(label, guidance, n_steps, seed)
        label_txt = json.dumps(label.numpy().astype(float).round(3).tolist())
        return im, seed, label_txt

    btn.click(
        main, [guidance, n_steps, seed, ref_im, ref_txt], [out_im, out_seed, out_emb]
    )

demo.launch()