Bingsu committed on
Commit
9225658
1 Parent(s): c04fa1a

feat: capture progressbar

Browse files
Files changed (1) hide show
  1. app.py +35 -17
app.py CHANGED
@@ -1,13 +1,14 @@
1
  from __future__ import annotations
2
 
 
3
  import shutil
4
  import subprocess
5
  from pathlib import Path
6
  from textwrap import dedent
7
 
8
- import torch
9
- import streamlit as st
10
  import numpy as np
 
 
11
  from PIL import Image
12
  from transformers import CLIPTokenizer
13
 
@@ -22,6 +23,7 @@ color = col1.color_picker("Pick a color", "#00f900")
22
  col2.text_input("", color, disabled=True)
23
 
24
  emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
 
25
  rgb = hex_to_rgb(color)
26
 
27
  img_array = np.zeros((128, 128, 3), dtype=np.uint8)
@@ -38,23 +40,22 @@ if output_path.exists():
38
  dataset_path.mkdir()
39
  img_path = dataset_path / f"{emb_name}.png"
40
  Image.fromarray(img_array).save(img_path)
41
- tokenizer = CLIPTokenizer.from_pretrained(
42
- "Linaqruf/anything-v3.0", subfolder="tokenizer"
43
- )
44
 
45
  with st.sidebar:
46
- init_text = st.text_input("Initializer", "init token name")
47
  steps = st.slider("Steps", 1, 100, 30, step=1)
48
  learning_rate = st.text_input("Learning rate", "0.005")
49
  learning_rate = float(learning_rate)
50
 
51
- # case 1: init_text is not a single token
52
- token = tokenizer.tokenize(init_text)
 
 
53
  if len(token) > 1:
54
- st.warning("init_text must be a single token")
55
  st.stop()
56
 
57
- # case 2: init_text already exists in the tokenizer
58
  num_added_tokens = tokenizer.add_tokens(emb_name)
59
  if num_added_tokens == 0:
60
  st.warning(f"The tokenizer already contains the token {emb_name}")
@@ -62,7 +63,7 @@ if num_added_tokens == 0:
62
 
63
  cmd = """
64
  accelerate launch textual_inversion.py \
65
- --pretrained_model_name_or_path="Linaqruf/anything-v3.0" \
66
  --train_data_dir="dataset" \
67
  --learnable_property="style" \
68
  --placeholder_token="{emb_name}" \
@@ -78,22 +79,39 @@ accelerate launch textual_inversion.py \
78
  """.strip()
79
 
80
  cmd = dedent(cmd).format(
81
- emb_name=emb_name, init=init_text, lr=learning_rate, steps=steps
 
 
 
 
82
  )
 
 
 
 
83
 
84
- if st.button("Start"):
 
 
 
85
  with st.spinner("Training..."):
86
- subprocess.run(cmd, shell=True)
 
 
 
 
 
 
 
87
 
88
- result_path = Path("output") / "learned_embeds.bin"
89
  if not result_path.exists():
90
  st.stop()
91
 
92
- # fix unknown error
93
  trained_emb = torch.load(result_path, map_location="cpu")
94
  for k, v in trained_emb.items():
95
  trained_emb[k] = torch.from_numpy(v.numpy())
96
  torch.save(trained_emb, result_path)
97
 
98
  file = result_path.read_bytes()
99
- st.download_button("Download", file, f"{emb_name}.pt")
 
1
  from __future__ import annotations
2
 
3
+ import shlex
4
  import shutil
5
  import subprocess
6
  from pathlib import Path
7
  from textwrap import dedent
8
 
 
 
9
  import numpy as np
10
+ import streamlit as st
11
+ import torch
12
  from PIL import Image
13
  from transformers import CLIPTokenizer
14
 
 
23
  col2.text_input("", color, disabled=True)
24
 
25
  emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
26
+ init_token = st.text_input("Initializer token", "init token name")
27
  rgb = hex_to_rgb(color)
28
 
29
  img_array = np.zeros((128, 128, 3), dtype=np.uint8)
 
40
  dataset_path.mkdir()
41
  img_path = dataset_path / f"{emb_name}.png"
42
  Image.fromarray(img_array).save(img_path)
 
 
 
43
 
44
  with st.sidebar:
45
+ model_name = st.text_input("Model name", "Linaqruf/anything-v3.0")
46
  steps = st.slider("Steps", 1, 100, 30, step=1)
47
  learning_rate = st.text_input("Learning rate", "0.005")
48
  learning_rate = float(learning_rate)
49
 
50
+ tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
51
+
52
+ # case 1: init_token is not a single token
53
+ token = tokenizer.tokenize(init_token)
54
  if len(token) > 1:
55
+ st.warning("Initializer token must be a single token")
56
  st.stop()
57
 
58
+ # case 2: emb_name already exists in the tokenizer
59
  num_added_tokens = tokenizer.add_tokens(emb_name)
60
  if num_added_tokens == 0:
61
  st.warning(f"The tokenizer already contains the token {emb_name}")
 
63
 
64
  cmd = """
65
  accelerate launch textual_inversion.py \
66
+ --pretrained_model_name_or_path={model_name} \
67
  --train_data_dir="dataset" \
68
  --learnable_property="style" \
69
  --placeholder_token="{emb_name}" \
 
79
  """.strip()
80
 
81
  cmd = dedent(cmd).format(
82
+ model_name=model_name,
83
+ emb_name=emb_name,
84
+ init=init_token,
85
+ lr=learning_rate,
86
+ steps=steps,
87
  )
88
+ cmd = shlex.split(cmd)
89
+
90
+ result_path = output_path / "learned_embeds.bin"
91
+ captured = ""
92
 
93
+ start_button = st.button("Start")
94
+ download_button = st.empty()
95
+
96
+ if start_button:
97
  with st.spinner("Training..."):
98
+ placeholder = st.empty()
99
+ p = subprocess.Popen(
100
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8"
101
+ )
102
+
103
+ while line := p.stderr.readline():
104
+ captured += line
105
+ placeholder.code(captured, language="bash")
106
 
 
107
  if not result_path.exists():
108
  st.stop()
109
 
110
+ # workaround for an unexplained file volume (size?) bug: round-trip each tensor through NumPy, then re-save
111
  trained_emb = torch.load(result_path, map_location="cpu")
112
  for k, v in trained_emb.items():
113
  trained_emb[k] = torch.from_numpy(v.numpy())
114
  torch.save(trained_emb, result_path)
115
 
116
  file = result_path.read_bytes()
117
+ download_button.download_button(f"Download {emb_name}.pt", file, f"{emb_name}.pt")