Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import shlex | |
import subprocess | |
from pathlib import Path | |
from tempfile import TemporaryDirectory | |
from textwrap import dedent | |
import numpy as np | |
import streamlit as st | |
import torch | |
from PIL import Image | |
from transformers import CLIPTokenizer | |
def hex_to_rgb(s: str) -> tuple[int, int, int]:
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints in 0-255.

    A leading ``#`` and surrounding whitespace are ignored. The six-digit
    ``RRGGBB`` form (e.g. the ``#00f900`` default of the colour picker) is
    the main input. The CSS shorthand ``RGB``/``RGBA`` is expanded to six
    digits. Any digits after the first six (such as an ``AA`` alpha channel)
    are ignored.

    Args:
        s: The colour string, e.g. ``"#00f900"``, ``"00F900"`` or ``"#abc"``.

    Returns:
        The red, green and blue components as integers.

    Raises:
        ValueError: If fewer than six hex digits remain after shorthand
            expansion, or if the first six characters are not all hex digits.
    """
    value = s.strip().lstrip("#")
    if len(value) in (3, 4):
        # CSS shorthand: "#abc" means "#aabbcc" (a 4th alpha digit is dropped below).
        value = "".join(ch * 2 for ch in value)
    digits = value[:6]
    # Validate explicitly: int(x, 16) would otherwise accept signs/whitespace
    # such as "+f" and give a nonsensical colour.
    if len(digits) != 6 or any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError(f"Invalid hex color: {s!r}")
    return (int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16))
# Page title, followed by a collapsible "info" section showing info.txt.
st.header("Color Textual Inversion")
with st.expander(label="info"):
    st.markdown(Path("info.txt").read_text(encoding="utf-8"))

# Raw HTML badge linking to the "duplicate this Space" page; it needs
# unsafe_allow_html because st.markdown escapes HTML by default.
duplicate_button = (
    '<a class="duplicate-button" style="display:inline-block" target="_blank" '
    'href="https://huggingface.co/spaces/Bingsu/color_textual_inversion?duplicate=true">'
    '<img style="margin: 0" '
    'src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" '
    'alt="Duplicate Space"></a>'
)
st.markdown(duplicate_button, unsafe_allow_html=True)
# Colour picker, with its hex value echoed read-only beside it.
col1, col2 = st.columns([15, 85])
color = col1.color_picker("Pick a color", "#00f900")
col2.text_input("", color, disabled=True)
# The embedding name defaults to the upper-cased hex digits of the colour.
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
init_token = st.text_input("Initializer token", "init token name")

# A single solid-colour 128x128 RGB image is the whole training set.
rgb = hex_to_rgb(color)
img_array = np.full((128, 128, 3), rgb, dtype=np.uint8)

# Scratch directories (created under the current directory) for the
# one-image dataset and the training output.
dataset_temp = TemporaryDirectory(prefix="dataset_", dir=".")
dataset_path = Path(dataset_temp.name)
output_temp = TemporaryDirectory(prefix="output_", dir=".")
output_path = Path(output_temp.name)

# emb_name is free user text. Reduce it to a safe filename stem so values
# like "../x", "a/b" or "" cannot escape or break the dataset directory.
safe_stem = "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in emb_name) or "color"
img_path = dataset_path / f"{safe_stem}.png"
Image.fromarray(img_array).save(img_path)
# Training settings live in the sidebar.
with st.sidebar:
    model_name = st.text_input("Model name", "Linaqruf/anything-v3.0")
    steps = st.slider("Steps", 1, 100, value=1, step=1)
    learning_rate_text = st.text_input("Learning rate", "0.001")

# Parse the learning rate here so a typo shows a warning instead of an
# uncaught ValueError. NaN, infinities and values <= 0 are rejected as well.
try:
    learning_rate = float(learning_rate_text)
except ValueError:
    learning_rate = None
if learning_rate is None or not 0 < learning_rate < float("inf"):
    st.warning(f"Learning rate must be a positive number, got {learning_rate_text!r}")
    st.stop()

# The embedding name is added to the tokenizer as the new placeholder token,
# so a blank name is rejected up front.
if not emb_name.strip():
    st.warning("Embedding name must not be empty")
    st.stop()

tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")

# case 1: init_token must tokenize to exactly one token. An empty or
# whitespace-only value yields zero tokens and is rejected too.
init_tokens = tokenizer.tokenize(init_token)
if len(init_tokens) != 1:
    st.warning("Initializer token must be a single token")
    st.stop()

# case 2: emb_name (the placeholder token) already exists in the tokenizer.
# add_tokens returns how many new tokens were actually added.
num_added_tokens = tokenizer.add_tokens(emb_name)
if num_added_tokens == 0:
    st.warning(f"The tokenizer already contains the token {emb_name}")
    st.stop()
# argv for the training run, handed to subprocess.Popen as a list (no shell).
# Building the list directly keeps each user-supplied value (model name,
# embedding name, initializer token) as exactly one argument. Formatting them
# into a single string and re-splitting with shlex would let quotes or spaces
# in those fields break the split or inject extra flags.
cmd = [
    "accelerate",
    "launch",
    "textual_inversion.py",
    f"--pretrained_model_name_or_path={model_name}",
    f"--train_data_dir={dataset_path.as_posix()}",
    "--learnable_property=style",
    f"--placeholder_token={emb_name}",
    f"--initializer_token={init_token}",
    "--resolution=128",
    "--train_batch_size=1",
    "--repeats=1",
    "--gradient_accumulation_steps=1",
    f"--max_train_steps={steps}",
    f"--learning_rate={learning_rate}",
    f"--output_dir={output_path.as_posix()}",
    "--only_save_embeds",
]
# File the training script is expected to write into output_path.
result_path = output_path / "learned_embeds.bin"
captured = ""
start_button = st.button("Start")
# Slot reserved next to the Start button; filled with a download button once
# an embedding has been produced.
download_button = st.empty()

try:
    if start_button:
        with st.spinner("Training..."):
            placeholder = st.empty()
            # Merge stderr into stdout so a single pipe is drained. Leaving a
            # second pipe unread can block the child once its OS buffer fills.
            # errors="replace" keeps undecodable output from aborting the loop.
            with subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                encoding="utf-8",
                errors="replace",
            ) as proc:
                while line := proc.stdout.readline():
                    captured += line
                    placeholder.code(captured, language="bash")
                returncode = proc.wait()

        if not result_path.exists():
            st.error(
                f"Training did not produce {result_path.name} "
                f"(exit code {returncode})."
            )
            st.stop()

        # fix unknown file volume bug: torch.save writes a view's entire
        # underlying storage. If the saved tensors are views into a larger
        # tensor (assumed — textual_inversion.py is not visible here), the
        # file is far bigger than the embedding itself. clone() gives every
        # tensor its own compact storage before re-saving.
        trained_emb = torch.load(result_path, map_location="cpu")
        trained_emb = {k: v.detach().clone() for k, v in trained_emb.items()}
        torch.save(trained_emb, result_path)

        data = result_path.read_bytes()
        # One button in the reserved slot and one at the bottom of the page.
        # The trailing space keeps the second label distinct, presumably to
        # avoid a duplicate-widget clash.
        download_button.download_button(f"Download {emb_name}.pt", data, f"{emb_name}.pt")
        st.download_button(f"Download {emb_name}.pt ", data, f"{emb_name}.pt")
finally:
    # Runs even when st.stop() raises above, so the scratch dirs never linger.
    # The embedding bytes are already in memory for the download buttons.
    dataset_temp.cleanup()
    output_temp.cleanup()