Bingsu's picture
fix: change default value
72cb885
from __future__ import annotations
import shlex
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory
from textwrap import dedent
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTokenizer
def hex_to_rgb(s: str) -> tuple[int, int, int]:
    """Convert a hex color string into an ``(r, g, b)`` tuple of ints.

    Leading ``#`` characters are stripped. A three-digit shorthand such as
    ``#abc`` is expanded by doubling each digit (``aabbcc``). For any other
    length the first six characters are read as three two-digit channels and
    anything after them is ignored.

    Args:
        s: Hex color, with or without a leading ``#``.

    Returns:
        The red, green and blue channel values.

    Raises:
        ValueError: If a channel is missing or is not valid hexadecimal.
    """
    value = s.lstrip("#")
    if len(value) == 3:
        # CSS-style shorthand: each digit stands for a doubled pair.
        value = "".join(digit * 2 for digit in value)
    return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16))
# Page title, followed by a collapsible panel rendering info.txt as markdown.
st.header("Color Textual Inversion")
with st.expander(label="info"):
    st.markdown(Path("info.txt").read_text(encoding="utf-8"))

# Raw HTML badge linking to this Space's "duplicate" page; it is rendered with
# unsafe_allow_html so the markup is not escaped.
duplicate_button = """<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Bingsu/color_textual_inversion?duplicate=true"><img style="margin: 0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>"""
st.markdown(duplicate_button, unsafe_allow_html=True)
# Narrow column for the color picker, wide column echoing the chosen hex value
# in a read-only text box.
picker_col, hex_col = st.columns([15, 85])
color = picker_col.color_picker("Pick a color", "#00f900")
hex_col.text_input("", color, disabled=True)

# The embedding name defaults to the upper-cased hex digits of the color.
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
init_token = st.text_input("Initializer token", "init token name")

# 128x128 RGB image filled entirely with the picked color; the three channel
# values are broadcast across every pixel.
rgb = hex_to_rgb(color)
img_array = np.full((128, 128, 3), rgb, dtype=np.uint8)
# Training pipeline: collect settings, validate the inputs, write a one-image
# dataset into a temporary directory, run textual_inversion.py through
# `accelerate launch`, and offer the learned embedding for download.

# Training settings live in the sidebar.
with st.sidebar:
    model_name = st.text_input("Model name", "Linaqruf/anything-v3.0")
    steps = st.slider("Steps", 1, 100, value=1, step=1)
    learning_rate = st.text_input("Learning rate", "0.001")

# The learning rate arrives as free text; report bad input to the user instead
# of letting float() raise an unhandled ValueError.
try:
    learning_rate = float(learning_rate)
except ValueError:
    st.warning("Learning rate must be a number")
    st.stop()

# emb_name is used as a file name below, so it must be a plain, non-empty name
# with no directory components.
if not emb_name or Path(emb_name).name != emb_name:
    st.warning("Embedding name must be a non-empty name without path separators")
    st.stop()

tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")

# case 1: init_token is not exactly one token (this also rejects input that
# tokenizes to nothing, e.g. an empty string)
token = tokenizer.tokenize(init_token)
if len(token) != 1:
    st.warning("Initializer token must be a single token")
    st.stop()

# case 2: emb_name already exists in the tokenizer (add_tokens added nothing)
num_added_tokens = tokenizer.add_tokens(emb_name)
if num_added_tokens == 0:
    st.warning(f"The tokenizer already contains the token {emb_name}")
    st.stop()

# The temporary directories are created only after validation has passed and
# are managed by `with`, so they are removed even when st.stop() or an
# exception ends the script run early.
dataset_temp = TemporaryDirectory(prefix="dataset_", dir=".")
output_temp = TemporaryDirectory(prefix="output_", dir=".")
with dataset_temp, output_temp:
    dataset_path = Path(dataset_temp.name)
    output_path = Path(output_temp.name)

    # The dataset consists of a single solid-color image.
    img_path = dataset_path / f"{emb_name}.png"
    Image.fromarray(img_array).save(img_path)

    # Build the argument vector directly: every user-supplied value stays a
    # single argument regardless of spaces or quote characters in it.
    cmd = [
        "accelerate",
        "launch",
        "textual_inversion.py",
        f"--pretrained_model_name_or_path={model_name}",
        f"--train_data_dir={dataset_path.as_posix()}",
        "--learnable_property=style",
        f"--placeholder_token={emb_name}",
        f"--initializer_token={init_token}",
        "--resolution=128",
        "--train_batch_size=1",
        "--repeats=1",
        "--gradient_accumulation_steps=1",
        f"--max_train_steps={steps}",
        f"--learning_rate={learning_rate}",
        f"--output_dir={output_path.as_posix()}",
        "--only_save_embeds",
    ]

    result_path = output_path / "learned_embeds.bin"
    captured = ""

    start_button = st.button("Start")
    # Placeholder so a download button can appear directly under "Start",
    # above the training log.
    download_button = st.empty()

    if start_button:
        with st.spinner("Training..."):
            placeholder = st.empty()
            # Only stderr is displayed, so stdout is discarded rather than
            # piped: an unread pipe can fill up and block the child process.
            # Using Popen as a context manager closes the pipe and waits for
            # the process to exit.
            with subprocess.Popen(
                cmd,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                encoding="utf-8",
            ) as p:
                for line in p.stderr:
                    captured += line
                    placeholder.code(captured, language="bash")

        if not result_path.exists():
            st.error(
                f"Training did not produce {result_path.name} "
                f"(exit code {p.returncode})"
            )
            st.stop()

        # fix unknown file volume bug
        # NOTE(review): presumably the saved tensors are views into a larger
        # storage, and the numpy round trip keeps only the embedding data when
        # the file is saved again -- confirm.
        trained_emb = torch.load(result_path, map_location="cpu")
        for k, v in trained_emb.items():
            trained_emb[k] = torch.from_numpy(v.numpy())
        torch.save(trained_emb, result_path)

        # Read the bytes before the temporary directory is removed.
        file = result_path.read_bytes()
        download_button.download_button(
            f"Download {emb_name}.pt", file, f"{emb_name}.pt"
        )
        # Second button below the log; the trailing space in the label keeps
        # the two widgets distinct.
        st.download_button(f"Download {emb_name}.pt ", file, f"{emb_name}.pt")