|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
from torchvision import transforms |
|
from diffusers import StableDiffusionPipeline, StableDiffusionImageVariationPipeline, DiffusionPipeline |
|
import numpy as np |
|
import pandas as pd |
|
import math |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
|
|
|
|
|
|
|
# Hugging Face model identifiers.
# NOTE(review): text_model_id is not referenced anywhere in the visible code.
text_model_id = "runwayml/stable-diffusion-v1-5"
# Weights for the image-variation diffusion pipeline (loaded into `pipe`).
model_id = "lambdalabs/sd-image-variations-diffusers"
# CLIP checkpoint whose processor/model produce the text features.
clip_model_id = "openai/clip-vit-large-patch14-336"

# Number of input tabs shown in the UI.
max_tabs = 5

# One slot per tab; each slot is replaced by a Gradio component when the UI is built.
input_images = [None] * max_tabs
input_prompts = [None] * max_tabs
embedding_plots = [None] * max_tabs
embedding_base64s = [None] * max_tabs
|
|
|
|
|
|
|
def image_to_embedding(input_im):
    """Encode an image into an embedding with the pipeline's image encoder.

    The image is converted to a tensor, resized to 224x224 (bicubic, no
    antialiasing), normalised with the CLIP mean/std constants below and run
    through ``pipe.image_encoder``.

    Args:
        input_im: An image accepted by ``transforms.ToTensor`` (PIL image or
            HxWxC array; presumably the numpy array delivered by ``gr.Image`` --
            confirm).

    Returns:
        numpy.ndarray: The ``image_embeds`` row for this single image (batch
        dimension removed), in the image encoder's parameter dtype.
    """
    tform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(
            (224, 224),
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=False,
        ),
        transforms.Normalize(
            [0.48145466, 0.4578275, 0.40821073],
            [0.26862954, 0.26130258, 0.27577711]),
    ])

    inp = tform(input_im).to(device)
    # Match the encoder's weights dtype (the pipeline is loaded in float16).
    dtype = next(pipe.image_encoder.parameters()).dtype
    # Add a leading batch dimension of 1 before running the encoder.
    image = inp.unsqueeze(0).to(device=device, dtype=dtype)

    # Inference only: without no_grad every call would build an autograd graph
    # over the encoder's parameters and hold its activations in memory.
    with torch.no_grad():
        image_embeddings = pipe.image_encoder(image).image_embeds

    return image_embeddings[0].cpu().numpy()
|
|
|
def prompt_to_embedding(prompt):
    """Encode a text prompt into a CLIP text-feature vector.

    Args:
        prompt: Text to encode. Prompts longer than the 77-token CLIP context
            are truncated instead of overflowing the model's input length.

    Returns:
        numpy.ndarray: ``model.get_text_features`` output for the prompt
        (batch dimension removed).
    """
    # Pad to, and truncate at, the fixed 77-token context. Without
    # truncation=True a long prompt yields more than 77 input ids, which
    # overruns CLIP's 77-entry position embedding table.
    inputs = processor(
        text=prompt,
        return_tensors="pt",
        padding='max_length',
        max_length=77,
        truncation=True,
    )
    prompt_tokens = inputs.input_ids

    with torch.no_grad():
        prompt_embeddings = model.get_text_features(prompt_tokens.to(device))

    return prompt_embeddings[0].cpu().numpy()
|
|
|
def embedding_to_image(embeddings):
    """Render an embedding vector as a square greyscale image.

    The values are flattened, zero-padded up to the next perfect square,
    min-max scaled into 0..255 and reshaped to ``(size, size)``.

    Args:
        embeddings: Array-like of numeric embedding values.

    Returns:
        PIL.Image.Image: An 8-bit greyscale (mode ``"L"``) image. A constant
        input renders as all-black.
    """
    values = np.asarray(embeddings, dtype=np.float32).ravel()
    size = math.ceil(math.sqrt(values.shape[0]))
    padded = np.pad(values, (0, size**2 - values.shape[0]), 'constant')

    # Float data cannot be handed straight to an 8-bit "L" image: Pillow would
    # reinterpret the raw float bytes as pixels (or reject the dtype).
    # Scale into the 0..255 range first.
    low, high = (padded.min(), padded.max()) if padded.size else (0.0, 0.0)
    if high > low:
        scaled = (padded - low) / (high - low) * 255.0
    else:
        scaled = np.zeros_like(padded)

    # reshape (not in-place ndarray.resize, which can refuse referenced arrays).
    pixels = scaled.astype(np.uint8).reshape(size, size)
    return Image.fromarray(pixels)
|
|
|
def embedding_to_base64(embeddings):
    """Serialise an embedding as URL-safe base64 text.

    The values are cast to float16 and the raw array bytes are encoded with
    the URL-safe base64 alphabet.

    Args:
        embeddings: numpy array of embedding values.

    Returns:
        str: ASCII base64 text; ``base64_to_embedding`` decodes it back.
    """
    import base64

    half_precision = embeddings.astype(np.float16)
    encoded_bytes = base64.urlsafe_b64encode(half_precision)
    return encoded_bytes.decode()
|
|
|
def base64_to_embedding(embeddings_b64):
    """Decode URL-safe base64 text back into a float16 embedding vector.

    Args:
        embeddings_b64: Base64 text as produced by ``embedding_to_base64``.

    Returns:
        numpy.ndarray: 1-D float16 array built with ``np.frombuffer`` over the
        decoded bytes (a read-only view -- copy before mutating in place).
    """
    import base64

    return np.frombuffer(
        base64.urlsafe_b64decode(embeddings_b64),
        dtype=np.float16,
    )
|
|
|
def main(
    embeddings,
    scale=3.0,
    n_samples=4,
    steps=25,
    seed=0
):
    """Generate images conditioned on a base64-encoded embedding.

    Args:
        embeddings: URL-safe base64 embedding text (see ``embedding_to_base64``).
            Empty or None falls back to the empty-prompt embedding.
        scale: Passed to the pipeline as ``guidance_scale``.
        n_samples: Number of pipeline runs; each run's images are collected.
        steps: Passed to the pipeline as ``num_inference_steps``.
        seed: Seed for the torch generator; None picks a random seed.

    Returns:
        list: All images from the ``"images"`` output of every pipeline run.
    """
    if seed is None:
        # Seed box left empty: draw a random non-negative seed below 2**31 - 1.
        seed = np.random.randint(2147483647)
    generator = torch.Generator().manual_seed(int(seed))

    if embeddings:
        embedding_array = base64_to_embedding(embeddings)
    else:
        # Nothing loaded yet (e.g. Submit before any tab produced an embedding):
        # use the empty-prompt embedding like the on_* handlers do, cast to the
        # float16 that base64_to_embedding would have returned.
        embedding_array = prompt_to_embedding('').astype(np.float16)
    embedding_tensor = torch.tensor(embedding_array).to(device)

    images = []
    # Honour n_samples with one pipeline run per requested image. The shared
    # generator advances between runs, so each run draws different latents
    # while the whole batch stays reproducible for a given seed.
    for _ in range(int(n_samples)):
        result = pipe(
            embedding_tensor,
            guidance_scale=scale,
            num_inference_steps=steps,
            generator=generator,
        )
        images.extend(result["images"])

    return images
|
|
|
def on_image_load_update_embeddings(image_data):
    """Recompute a tab's embedding text when its image input changes.

    Args:
        image_data: The new image value, or None when the image was cleared.

    Returns:
        A ``gr.Text.update`` with the base64 embedding of the image, or of the
        empty prompt when there is no image.
    """
    vector = (
        prompt_to_embedding('') if image_data is None
        else image_to_embedding(image_data)
    )
    return gr.Text.update(embedding_to_base64(vector))
|
|
|
def on_prompt_change_update_embeddings(prompt):
    """Recompute a tab's embedding text when its prompt is submitted.

    Args:
        prompt: Submitted prompt text; None is treated as the empty prompt.

    Returns:
        A ``gr.Text.update`` carrying the base64 text embedding of the prompt.
    """
    # The tokenizer cannot take None, so a missing prompt becomes "".
    if prompt is None:
        prompt = ""
    # Encode exactly once (the previous empty-prompt branch encoded twice).
    embeddings_b64 = embedding_to_base64(prompt_to_embedding(prompt))
    return gr.Text.update(embeddings_b64)
|
|
|
|
|
|
|
def on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_base64, idx):
    """Record one tab's embedding and return the average over all tabs.

    Args:
        embedding_base64s_state: List with one slot per tab holding base64
            embedding text (None or "" for empty tabs). Slot ``idx`` is
            overwritten in place; this function does not return the list.
        embedding_base64: The changed tab's new base64 embedding text.
        idx: Index of the tab that changed.

    Returns:
        A ``gr.Text.update`` with the base64 element-wise mean of all
        non-empty embeddings, or of the empty-prompt embedding when every slot
        is empty.
    """
    embedding_base64s_state[idx] = embedding_base64

    # Decode every filled slot, widening float16 -> float32 so the running sum
    # neither loses precision nor risks float16 overflow.
    decoded = [
        base64_to_embedding(encoded).astype(np.float32)
        for encoded in embedding_base64s_state
        if encoded  # skip empty tabs (None or "")
    ]

    if not decoded:
        embeddings_b64 = embedding_to_base64(prompt_to_embedding(''))
        return gr.Text.update(embeddings_b64)

    # embedding_to_base64 narrows the float32 mean back to float16.
    average = np.mean(decoded, axis=0)
    return gr.Text.update(embedding_to_base64(average))
|
|
|
def on_embeddings_changed_update_plot(embeddings_b64):
    """Redraw a LinePlot showing each component of a base64 embedding.

    Args:
        embeddings_b64: Base64 embedding text from a textbox (may be None/"").

    Returns:
        A ``gr.LinePlot.update``: a no-op update when there is no text,
        otherwise value-vs-index data with one pixel of width per component.
    """
    if not embeddings_b64:
        return gr.LinePlot.update()

    values = base64_to_embedding(embeddings_b64)
    count = values.shape[0]
    frame = pd.DataFrame({
        'embedding': values,
        'index': list(range(count)),
    })

    return gr.LinePlot.update(
        frame,
        x="index",
        y="embedding",
        title="Embeddings",
        tooltip=['index', 'embedding'],
        width=count,
    )
|
|
|
|
|
# Pick the compute device: Apple MPS first, then the first CUDA GPU, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Image-variation diffusion pipeline, built from the local custom pipeline.py
# with float16 weights and no safety checker.
# NOTE(review): float16 is used on every device, including the CPU fallback,
# where half-precision inference may be unsupported or very slow -- confirm.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    custom_pipeline="pipeline.py",
    torch_dtype=torch.float16,
    requires_safety_checker=False,
    safety_checker=None,
    # NOTE(review): these pass the transformers *classes*, not loaded
    # instances; presumably placeholders to satisfy the pipeline constructor.
    # Confirm against pipeline.py.
    text_encoder=CLIPTextModel,
    tokenizer=CLIPTokenizer,
).to(device)

# CLIP processor + model used by prompt_to_embedding for text features.
from transformers import AutoProcessor, AutoModel

processor = AutoProcessor.from_pretrained(clip_model_id)
model = AutoModel.from_pretrained(clip_model_id).to(device)
|
|
|
# Example rows: an image file followed by values that line up with main()'s
# scale, n_samples, steps and seed parameters.
# NOTE(review): not referenced elsewhere in the visible code.
examples = [
    [example_image, 3, 1, 25, 0]
    for example_image in ("frog.png", "img0.jpg", "img1.jpg", "img2.jpg", "img3.jpg")
]
|
|
|
|
|
# Gradio UI: max_tabs input tabs (each turns an image or prompt into a base64
# embedding), the averaged embedding, generation controls and an output gallery.
with gr.Blocks() as demo:
    # One tab per input slot; tab i fills slot i of the module-level lists.
    with gr.Row():
        for i in range(max_tabs):
            with gr.Tab(f"Input {i}"):
                # Top row: image input (left), this tab's embedding plot (right).
                with gr.Row():
                    with gr.Column(scale=1, min_width=240):
                        input_images[i] = gr.Image()
                    with gr.Column(scale=3, min_width=600):
                        embedding_plots[i] = gr.LinePlot(show_label=False).style(container=False)

                # Bottom row: prompt textbox (left), collapsible base64 embedding text (right).
                with gr.Row():
                    with gr.Column(scale=1, min_width=240):
                        input_prompts[i] = gr.Textbox()
                    with gr.Column(scale=3, min_width=600):
                        with gr.Accordion("Embeddings", open=False):
                            embedding_base64s[i] = gr.Textbox(show_label=False)

    # Average of all non-empty tab embeddings: its plot and its base64 text,
    # which is the embedding passed to main() on Submit.
    with gr.Row():
        average_embedding_plot = gr.LinePlot(show_label=False).style(container=False)
    with gr.Row():
        average_embedding_base64 = gr.Textbox(show_label=False)

    # Generation controls, passed to main() in this order by submit.click below.
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            scale = gr.Slider(0, 25, value=3, step=1, label="Guidance scale")
        with gr.Column(scale=1, min_width=200):
            n_samples = gr.Slider(1, 4, value=1, step=1, label="Number images")
        with gr.Column(scale=1, min_width=200):
            steps = gr.Slider(5, 50, value=25, step=5, label="Steps")
        with gr.Column(scale=1, min_width=200):
            # Starts as None; main() picks a random seed when it receives None.
            seed = gr.Number(None, label="Seed", precision=0)
    with gr.Row():
        submit = gr.Button("Submit")
    with gr.Row():
        output = gr.Gallery(label="Generated variations")

    # Per-session list holding each tab's latest base64 embedding, indexed by tab.
    embedding_base64s_state = gr.State(value=[None for i in range(max_tabs)])
    for i in range(max_tabs):
        # Image change / prompt submit -> recompute this tab's embedding text.
        input_images[i].change(on_image_load_update_embeddings, input_images[i], [embedding_base64s[i]])
        input_prompts[i].submit(on_prompt_change_update_embeddings, input_prompts[i], [embedding_base64s[i]])
        # Embedding text change -> redraw this tab's plot.
        embedding_base64s[i].change(on_embeddings_changed_update_plot, embedding_base64s[i], [embedding_plots[i]])

        # Event inputs must be components, so the tab index travels as its own State.
        idx_state = gr.State(value=i)
        # Embedding text change -> store it in the shared state and recompute the average.
        embedding_base64s[i].change(on_embeddings_changed_update_average_embeddings, [embedding_base64s_state, embedding_base64s[i], idx_state], average_embedding_base64)

    # Average embedding text change -> redraw the average plot.
    average_embedding_base64.change(on_embeddings_changed_update_plot, average_embedding_base64, average_embedding_plot)

    submit.click(main, inputs= [average_embedding_base64, scale, n_samples, steps, seed], outputs=output)
    # Show generated images two per row.
    output.style(grid=2)


# Launch the server only when run as a script.
if __name__ == "__main__":
    demo.launch()