#@title Prepare the Concepts Library to be used
import requests
import os
import gradio as gr
import wget
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from huggingface_hub import HfApi
from transformers import CLIPTextModel, CLIPTokenizer
import html
# Hub client used to enumerate the community concepts.
api = HfApi()

# Every model published by the "sd-concepts-library" author, sorted by likes
# in descending order (direction=-1). Iterated by the setup loop further down.
models_list = api.list_models(
    author="sd-concepts-library",
    sort="likes",
    direction=-1,
)

# One metadata dict per successfully loaded concept is appended here.
models = []

# Stable Diffusion v1-4 weights in half precision, moved to the GPU. Its
# tokenizer and text encoder are later extended with the learned tokens.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,  # use the locally stored Hugging Face login token
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    """Register a textual-inversion embedding with a tokenizer/text encoder.

    The file at ``learned_embeds_path`` must hold a mapping of
    ``{trained_token: embedding_tensor}``; only its first entry is used.
    The token is added to ``tokenizer``, the encoder's input-embedding
    matrix is resized to the new vocabulary size, and the embedding is
    written into the row of the new token id.

    Args:
        learned_embeds_path: Path to a file loadable with ``torch.load``
            (e.g. ``learned_embeds.bin``).
        text_encoder: Model exposing ``get_input_embeddings()`` and
            ``resize_token_embeddings(n)`` (e.g. ``CLIPTextModel``).
        tokenizer: Tokenizer exposing ``add_tokens``, ``__len__`` and
            ``convert_tokens_to_ids`` (e.g. ``CLIPTokenizer``).
        token: Token string to register; defaults to the key in the file.

    Returns:
        The token string actually registered. If the requested token is
        already in the vocabulary, a numbered variant derived from it is
        used: ``"<cat>"`` -> ``"<cat-1>"``, ``"<cat-2>"``, ...; a token not
        ending in ``">"`` gets a plain suffix (``"dog"`` -> ``"dog-1"``).

    Raises:
        ValueError: If the file contains no embeddings.
    """
    # NOTE(review): torch.load unpickles the file. These files are downloaded
    # from the Hub, so only load from trusted repos (consider
    # weights_only=True on torch versions that support it).
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    if not loaded_learned_embeds:
        raise ValueError(f"No embeddings found in {learned_embeds_path}")
    # separate token and the embeds (only the first entry is used)
    trained_token = next(iter(loaded_learned_embeds))
    embeds = loaded_learned_embeds[trained_token]
    # cast to dtype of text_encoder; Tensor.to returns a new tensor, so rebind
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)
    # add the token in tokenizer
    token = token if token is not None else trained_token
    base_token = token
    num_added_tokens = tokenizer.add_tokens(token)
    i = 1
    while num_added_tokens == 0:
        print(f"The tokenizer already contains the token {token}.")
        # Derive every candidate from the original token so retries yield
        # "<cat-1>", "<cat-2>", ... rather than accumulating "<cat-1-2>".
        if base_token.endswith(">"):
            token = f"{base_token[:-1]}-{i}>"
        else:
            token = f"{base_token}-{i}"
        print(f"Attempting to add the token {token}.")
        num_added_tokens = tokenizer.add_tokens(token)
        i += 1
    # resize the token embeddings to cover the newly added token
    text_encoder.resize_token_embeddings(len(tokenizer))
    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token
def _fetch_text(url, default=None):
    """Return the whitespace-stripped body of ``url`` on HTTP 200.

    Returns ``default`` when the request fails, the status is not 200
    (e.g. a missing file), or the body is empty, so that one broken
    concept repo does not abort the whole library setup.
    """
    try:
        response = requests.get(url, timeout=30)
    except requests.RequestException as err:
        print(f"Could not fetch {url}: {err}")
        return default
    if response.status_code != 200:
        return default
    return response.text.strip() or default


def _download_concept_images(model_id, max_images=4):
    """Download up to ``max_images`` preview JPEGs for ``model_id``.

    Each image is saved as ``{model_id}/{i}.jpeg``; missing or unreachable
    images are skipped. Returns the list of local paths that were written.
    """
    images = []
    for i in range(max_images):
        url = f"https://huggingface.co/{model_id}/resolve/main/concept_images/{i}.jpeg"
        try:
            image_download = requests.get(url, timeout=30)
        except requests.RequestException:
            continue
        if image_download.status_code != 200:
            continue
        path = f"{model_id}/{i}.jpeg"
        with open(path, "wb") as file:  # Creates the file and saves the image bytes
            file.write(image_download.content)
        images.append(path)
    return images


print("Setting up the public library")
for model in models_list:
    model_id = model.modelId
    model_content = {"id": model_id}
    embeds_path = f"{model_id}/learned_embeds.bin"
    embeds_url = f"https://huggingface.co/{model_id}/resolve/main/learned_embeds.bin"
    os.makedirs(model_id, exist_ok=True)
    if not os.path.exists(embeds_path):
        try:
            wget.download(embeds_url, out=model_id)
        except Exception as err:
            # Best-effort: a concept whose embeddings cannot be downloaded is skipped.
            print(f"Skipping {model_id}: could not download embeddings ({err})")
            continue
    # None makes load_learned_embed_in_clip fall back to the token stored in the file.
    token_name = _fetch_text(f"https://huggingface.co/{model_id}/raw/main/token_identifier.txt")
    model_content["concept_type"] = _fetch_text(
        f"https://huggingface.co/{model_id}/raw/main/type_of_concept.txt", default=""
    )
    model_content["images"] = _download_concept_images(model_id)
    learned_token = load_learned_embed_in_clip(
        embeds_path, pipe.text_encoder, pipe.tokenizer, token_name
    )
    model_content["token"] = learned_token
    models.append(model_content)
#@title Run the app to navigate around [the Library](https://huggingface.co/sd-concepts-library)
#@markdown Click the `Running on public URL:` result to run the Gradio app
# Label text "Select concept" — presumably for a concept picker in the UI;
# it is not referenced anywhere in the visible part of this file.
SELECT_LABEL: str = "Select concept"
def assembleHTML(model):
html_gallery = ''
html_gallery = html_gallery+'''
'''
for model in models:
html_gallery = html_gallery+f'''
Navigate through community created concepts and styles via Stable Diffusion Textual Inversion and pick yours for inference.
To train your own concepts and contribute to the library check out this notebook.
''')
with gr.Row():
with gr.Column():
gr.Markdown(f"### Navigate {len(models)}+ Textual-Inversion community trained concepts")
with gr.Row():
image_blocks = []
#for i, model in enumerate(models):
with gr.Box().style(border=None):
gr.HTML(assembleHTML(models))
#title_block(model["token"], model["id"])
#image_blocks.append(image_block(model["images"], model["concept_type"]))
with gr.Box():
with gr.Row(elem_id="prompt_area").style(mobile_collapse=False, equal_height=True):
text = gr.Textbox(
label="Enter your prompt", placeholder="Enter your prompt", show_label=False, max_lines=1, elem_id="prompt_input"
).style(
border=(True, False, True, True),
rounded=(True, False, False, True),
container=False
)
btn = gr.Button("Run",elem_id="run_btn").style(
margin=False,
rounded=(False, True, True, False)
)
with gr.Row().style():
infer_outputs = gr.Gallery(show_label=False).style(grid=[2], height="512px")
with gr.Row():
gr.HTML("
Prompting may not work as you are used to. objects may need the concept added at the end, styles may work better at the beginning. You can navigate on lexica.art to get inspired on prompts