from __future__ import annotations import shlex import shutil import subprocess from pathlib import Path from textwrap import dedent import numpy as np import streamlit as st import torch from PIL import Image from transformers import CLIPTokenizer def hex_to_rgb(s: str) -> tuple[int, int, int]: value = s.lstrip("#") return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16)) col1, col2 = st.columns([15, 85]) color = col1.color_picker("Pick a color", "#00f900") col2.text_input("", color, disabled=True) emb_name = st.text_input("Embedding name", color.lstrip("#").upper()) init_token = st.text_input("Initializer token", "init token name") rgb = hex_to_rgb(color) img_array = np.zeros((128, 128, 3), dtype=np.uint8) for i in range(3): img_array[..., i] = rgb[i] dataset_path = Path("dataset") output_path = Path("output") if dataset_path.exists(): shutil.rmtree(dataset_path) if output_path.exists(): shutil.rmtree(output_path) dataset_path.mkdir() img_path = dataset_path / f"{emb_name}.png" Image.fromarray(img_array).save(img_path) with st.sidebar: model_name = st.text_input("Model name", "Linaqruf/anything-v3.0") steps = st.slider("Steps", 1, 100, 30, step=1) learning_rate = st.text_input("Learning rate", "0.005") learning_rate = float(learning_rate) tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer") # case 1: init_token is not a single token token = tokenizer.tokenize(init_token) if len(token) > 1: st.warning("Initializer token must be a single token") st.stop() # case 2: init_token already exists in the tokenizer num_added_tokens = tokenizer.add_tokens(emb_name) if num_added_tokens == 0: st.warning(f"The tokenizer already contains the token {emb_name}") st.stop() cmd = """ accelerate launch textual_inversion.py \ --pretrained_model_name_or_path={model_name} \ --train_data_dir="dataset" \ --learnable_property="style" \ --placeholder_token="{emb_name}" \ --initializer_token="{init}" \ --resolution=128 \ --train_batch_size=1 \ --repeats=1 \ --gradient_accumulation_steps=1 \ --max_train_steps={steps} \ --learning_rate={lr} \ --output_dir="output" \ --only_save_embeds """.strip() cmd = dedent(cmd).format( model_name=model_name, emb_name=emb_name, init=init_token, lr=learning_rate, steps=steps, ) cmd = shlex.split(cmd) result_path = output_path / "learned_embeds.bin" captured = "" start_button = st.button("Start") download_button = st.empty() if start_button: with st.spinner("Training..."): placeholder = st.empty() p = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8" ) while line := p.stderr.readline(): captured += line placeholder.code(captured, language="bash") if not result_path.exists(): st.stop() # fix unknown file volume bug trained_emb = torch.load(result_path, map_location="cpu") for k, v in trained_emb.items(): trained_emb[k] = torch.from_numpy(v.numpy()) torch.save(trained_emb, result_path) file = result_path.read_bytes() download_button.download_button(f"Download {emb_name}.pt", file, f"{emb_name}.pt")