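"""Color Textual Inversion: a Streamlit app.

Pick a color, and the app builds a one-image dataset of that solid color, then
trains a textual-inversion embedding for it by launching the textual_inversion.py
training script through `accelerate`. The learned embedding is offered for download.
"""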
from __future__ import annotations
import shlex
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory
from textwrap import dedent
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTokenizer
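

# Convert a "#RRGGBB" hex string from the color picker into an (R, G, B) tuple.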
def hex_to_rgb(s: str) -> tuple[int, int, int]:
    value = s.lstrip("#")
    return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16))
st.header("Color Textual Inversion")
with st.expander(label="info"):
    with open("info.txt", "r", encoding="utf-8") as f:
        st.markdown(f.read())
duplicate_button = """<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Bingsu/color_textual_inversion?duplicate=true"><img style="margin: 0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>"""
st.markdown(duplicate_button, unsafe_allow_html=True)
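
# Color picker; the disabled text box next to it just echoes the chosen hex code.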
col1, col2 = st.columns([15, 85])
color = col1.color_picker("Pick a color", "#00f900")
col2.text_input("", color, disabled=True)
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
init_token = st.text_input("Initializer token", "init token name")
rgb = hex_to_rgb(color)
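
# Build a 128x128 image filled with the selected color; this single image is the whole training set.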
img_array = np.zeros((128, 128, 3), dtype=np.uint8)
for i in range(3):
    img_array[..., i] = rgb[i]
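
# Temporary working directories (created next to the app) for the one-image dataset and the training output.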
dataset_temp = TemporaryDirectory(prefix="dataset_", dir=".")
dataset_path = Path(dataset_temp.name)
output_temp = TemporaryDirectory(prefix="output_", dir=".")
output_path = Path(output_temp.name)
img_path = dataset_path / f"{emb_name}.png"
Image.fromarray(img_array).save(img_path)
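
# Sidebar: base model and training hyperparameters.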
with st.sidebar:
    model_name = st.text_input("Model name", "Linaqruf/anything-v3.0")
    steps = st.slider("Steps", 1, 100, value=1, step=1)
    learning_rate = st.text_input("Learning rate", "0.001")
    learning_rate = float(learning_rate)
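
# Load only the tokenizer so the placeholder and initializer tokens can be validated before training starts.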
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
# case 1: init_token is not a single token
token = tokenizer.tokenize(init_token)
if len(token) > 1:
    st.warning("Initializer token must be a single token")
    st.stop()
# case 2: emb_name (the placeholder token) already exists in the tokenizer
num_added_tokens = tokenizer.add_tokens(emb_name)
if num_added_tokens == 0:
    st.warning(f"The tokenizer already contains the token {emb_name}")
    st.stop()
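
# Command line for the textual_inversion.py training script, launched through accelerate;
# the {placeholders} are filled in right below.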
cmd = """
accelerate launch textual_inversion.py \
--pretrained_model_name_or_path={model_name} \
--train_data_dir={dataset_path} \
--learnable_property="style" \
--placeholder_token="{emb_name}" \
--initializer_token="{init}" \
--resolution=128 \
--train_batch_size=1 \
--repeats=1 \
--gradient_accumulation_steps=1 \
--max_train_steps={steps} \
--learning_rate={lr} \
--output_dir={output_path} \
--only_save_embeds
""".strip()
cmd = dedent(cmd).format(
    model_name=model_name,
    dataset_path=dataset_path.as_posix(),
    emb_name=emb_name,
    init=init_token,
    steps=steps,
    lr=learning_rate,
    output_path=output_path.as_posix(),
)
cmd = shlex.split(cmd)
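
# With --only_save_embeds, the training script writes the learned embedding to learned_embeds.bin in the output directory.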
result_path = output_path / "learned_embeds.bin"
captured = ""
start_button = st.button("Start")
download_button = st.empty()
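
# Run training as a subprocess and stream its stderr (where the progress logs appear) live into the page.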
if start_button:
    with st.spinner("Training..."):
        placeholder = st.empty()
        p = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8"
        )

        while line := p.stderr.readline():
            captured += line
            placeholder.code(captured, language="bash")
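
# Bail out if training did not produce an embedding file.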
if not result_path.exists():
    st.stop()
# Work around the "unknown file volume" bug in the saved checkpoint: round-trip
# every tensor through NumPy so the re-saved file contains plain CPU tensors.
trained_emb = torch.load(result_path, map_location="cpu")
for k, v in trained_emb.items():
    trained_emb[k] = torch.from_numpy(v.numpy())
torch.save(trained_emb, result_path)
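
# Offer the trained embedding for download, both in the placeholder near the Start button
# and as a second button below (the trailing space in its label keeps the two widgets distinct).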
file = result_path.read_bytes()
download_button.download_button(f"Download {emb_name}.pt", file, f"{emb_name}.pt")
st.download_button(f"Download {emb_name}.pt ", file, f"{emb_name}.pt")
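
# Remove the temporary dataset and output directories.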
dataset_temp.cleanup()
output_temp.cleanup()
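
# For reference only (not executed by this app): a minimal sketch of how the downloaded
# embedding might be used, assuming a recent diffusers release that provides
# `load_textual_inversion`. The file name and prompt below are illustrative placeholders.
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained("Linaqruf/anything-v3.0")
#   pipe.load_textual_inversion("00F900.pt", token="00F900")
#   image = pipe("a landscape in the style of 00F900").images[0]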