matsudatkm committed on
Commit
4d9b586
·
1 Parent(s): e2044cd

Textual Inversion

Browse files
Files changed (5) hide show
  1. app.py +44 -19
  2. models/diffusion.py +66 -1
  3. models/unet.py +6 -6
  4. train.py +61 -0
  5. utils.py +5 -6
app.py CHANGED
@@ -7,10 +7,13 @@ from PIL import Image
7
  from streamlit_drawable_canvas import st_canvas
8
 
9
  from models.diffusion import Diffusion
 
10
  from utils import initialize_data_dir, save_image
11
 
12
  # 設定
13
- st.set_page_config(page_title="手書き文字生成アプリ", layout="wide")
 
 
14
 
15
  # タイトル
16
  st.title("手書き文字生成アプリ")
@@ -32,7 +35,7 @@ num_samples = 5 # 各文字のサンプル数
32
  st.header("手書き文字を描いてください")
33
 
34
  # 保存用ディレクトリ
35
- data_dir = initialize_data_dir()
36
 
37
  # 描画領域の作成
38
  for char in characters:
@@ -53,33 +56,42 @@ for char in characters:
53
  if canvas.image_data is not None:
54
  img = Image.fromarray(
55
  canvas.image_data.astype("uint8"), "RGBA"
56
- ).convert("L")
57
- # 二値化
58
- img = img.point(lambda x: 0 if x < 128 else 255, "1")
59
  save_image(img, char, i, data_dir)
60
 
61
  # ハイパーパラメータの入力
62
- st.sidebar.header("ハイパーパラメータ設定")
63
  learning_rate = st.sidebar.number_input(
64
  "学習率", min_value=0.0001, max_value=1.0, value=0.001, step=0.0001, format="%.4f"
65
  )
66
  epochs = st.sidebar.number_input(
67
  "エポック数", min_value=1, max_value=100, value=10, step=1
68
  )
69
- optimizer_name = st.sidebar.selectbox("最適化手法", ["SGD", "Adam", "RMSprop"])
70
-
 
 
 
 
 
 
 
 
 
 
71
 
72
  # サンプリングの設定
73
  st.sidebar.header("サンプリング設定")
74
  noise_steps = st.sidebar.number_input(
75
- "ノイズステップ数", min_value=1, max_value=1000, value=1000, step=1
76
- )
77
- beta_start = st.sidebar.number_input(
78
- "βの初期値", min_value=0.0, max_value=1.0, value=0.0001, step=0.0001, format="%.4f"
79
  )
80
- beta_end = st.sidebar.number_input(
81
- "βの終了値", min_value=0.0, max_value=1.0, value=0.02, step=0.0001, format="%.4f"
82
  )
 
 
83
 
84
 
85
  # 生成ボタン
@@ -107,26 +119,39 @@ if st.button("生成"):
107
  torch.load(save_path, weights_only=True, map_location=device)
108
  )
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  # フォント画像の生成と表示
111
  def chuncked(iterable, n):
112
  for i in range(0, len(iterable), n):
113
  yield iterable[i : i + n]
114
 
115
- labels = list(range(46))
116
  columns_per_row = 5
117
  start_time = time.time()
118
  with st.spinner("フォント画像を生成中..."):
119
  labels_tensor = torch.tensor(labels).to(device)
120
- font_image = diffusion_model.sample(
121
- diffusion_model.model, labels_tensor
122
- )
123
  elapsed_time = time.time() - start_time
124
  st.success(f"フォント画像の生成に成功しました({elapsed_time:.2f}秒)")
125
  for label_row in chuncked(labels, columns_per_row):
126
  cols = st.columns(columns_per_row)
127
  for col, label in zip(cols, label_row):
 
128
  col.image(
129
- font_image[label].permute(1, 2, 0).cpu().numpy(),
130
  caption=f"{label}",
131
  use_container_width=True,
132
  )
 
7
  from streamlit_drawable_canvas import st_canvas
8
 
9
  from models.diffusion import Diffusion
10
+ from train import finetune
11
  from utils import initialize_data_dir, save_image
12
 
13
  # 設定
14
+ st.set_page_config(
15
+ page_title="手書き文字生成アプリ", layout="wide", page_icon=":pencil:"
16
+ )
17
 
18
  # タイトル
19
  st.title("手書き文字生成アプリ")
 
35
  st.header("手書き文字を描いてください")
36
 
37
  # 保存用ディレクトリ
38
+ data_dir = initialize_data_dir("./sample_images")
39
 
40
  # 描画領域の作成
41
  for char in characters:
 
56
  if canvas.image_data is not None:
57
  img = Image.fromarray(
58
  canvas.image_data.astype("uint8"), "RGBA"
59
+ ).convert("RGB")
60
+ # 輝度を逆転
61
+ img = Image.eval(img, lambda x: 255 - x)
62
  save_image(img, char, i, data_dir)
63
 
64
  # ハイパーパラメータの入力
65
+ st.sidebar.header("学習パラメータ設定")
66
  learning_rate = st.sidebar.number_input(
67
  "学習率", min_value=0.0001, max_value=1.0, value=0.001, step=0.0001, format="%.4f"
68
  )
69
  epochs = st.sidebar.number_input(
70
  "エポック数", min_value=1, max_value=100, value=10, step=1
71
  )
72
+ optimizer_name = st.sidebar.selectbox("最適化手法", ["SGD", "Adam", "RMSprop", "AdamW"])
73
+ if optimizer_name == "SGD":
74
+ optimizer = torch.optim.SGD
75
+ elif optimizer_name == "Adam":
76
+ optimizer = torch.optim.Adam
77
+ elif optimizer_name == "RMSprop":
78
+ optimizer = torch.optim.RMSprop
79
+ elif optimizer_name == "AdamW":
80
+ optimizer = torch.optim.AdamW
81
+ num_augmentations = st.sidebar.number_input(
82
+ "データ拡張回数", min_value=0, max_value=100, value=20, step=1
83
+ )
84
 
85
  # サンプリングの設定
86
  st.sidebar.header("サンプリング設定")
87
  noise_steps = st.sidebar.number_input(
88
+ "ノイズステップ数", min_value=2, max_value=1000, value=1000, step=1
 
 
 
89
  )
90
+ num_chars = st.sidebar.number_input(
91
+ "生成文字数", min_value=1, max_value=46, value=5, step=1
92
  )
93
+ beta_start = 0.0001
94
+ beta_end = 0.02
95
 
96
 
97
  # 生成ボタン
 
119
  torch.load(save_path, weights_only=True, map_location=device)
120
  )
121
 
122
+ # ファインチューニング
123
+ progress_bar = st.progress(0, text="学習中...")
124
+ diffusion_model = finetune(
125
+ data_dir,
126
+ diffusion_model,
127
+ criterion=torch.nn.MSELoss(),
128
+ optimizer=torch.optim.AdamW,
129
+ num_epochs=epochs,
130
+ learning_rate=learning_rate,
131
+ num_augmentations=num_augmentations,
132
+ progress_bar=progress_bar,
133
+ )
134
+ progress_bar.empty()
135
+
136
  # フォント画像の生成と表示
137
def chuncked(iterable, n):
    """Yield consecutive length-*n* slices of *iterable* (last may be shorter)."""
    start = 0
    while start < len(iterable):
        yield iterable[start : start + n]
        start += n
140
 
141
+ labels = list(range(num_chars))
142
  columns_per_row = 5
143
  start_time = time.time()
144
  with st.spinner("フォント画像を生成中..."):
145
  labels_tensor = torch.tensor(labels).to(device)
146
+ font_image = diffusion_model.sample(diffusion_model.model, labels_tensor)
 
 
147
  elapsed_time = time.time() - start_time
148
  st.success(f"フォント画像の生成に成功しました({elapsed_time:.2f}秒)")
149
  for label_row in chuncked(labels, columns_per_row):
150
  cols = st.columns(columns_per_row)
151
  for col, label in zip(cols, label_row):
152
+ # 輝度を逆転させて表示
153
  col.image(
154
+ 255 - font_image[label].permute(1, 2, 0).cpu().numpy(),
155
  caption=f"{label}",
156
  use_container_width=True,
157
  )
models/diffusion.py CHANGED
@@ -1,5 +1,7 @@
1
  import numpy as np
2
  import torch
 
 
3
  from .unet import UNet_conditional
4
 
5
 
@@ -49,7 +51,9 @@ class Diffusion:
49
  ノイズ画像とノイズの生成
50
  """
51
  sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
52
- sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
 
 
53
  noise = torch.randn_like(x)
54
  return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise
55
 
@@ -194,3 +198,64 @@ class Diffusion:
194
  history = np.vstack([history, item])
195
 
196
  return history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import torch
3
+ from streamlit.delta_generator import DeltaGenerator
4
+
5
  from .unet import UNet_conditional
6
 
7
 
 
51
  ノイズ画像とノイズの生成
52
  """
53
  sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
54
+ sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[
55
+ :, None, None, None
56
+ ]
57
  noise = torch.randn_like(x)
58
  return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise
59
 
 
198
  history = np.vstack([history, item])
199
 
200
  return history
201
+
202
+ def fit_s(
203
+ self,
204
+ criterion: torch.nn.Module,
205
+ optimizer: torch.optim.Optimizer,
206
+ num_epochs: int,
207
+ learning_rate: float,
208
+ data_loader: torch.utils.data.DataLoader,
209
+ device: torch.device,
210
+ progress_bar: DeltaGenerator,
211
+ ) -> None:
212
+ """
213
+ Textual Inversionの学習
214
+ """
215
+ # 最小損失の初期化
216
+ min_train_loss = 9e9
217
+
218
+ # sのみを学習可能なパラメータとして定義
219
+ for params in self.model.parameters():
220
+ params.requires_grad = False
221
+ self.model.s = torch.nn.Parameter(
222
+ 1e-2 * torch.randn(1, self.time_dim, device=device)
223
+ )
224
+ print("sの初期値: ", self.model.s)
225
+ optimizer = optimizer([self.model.s], lr=learning_rate)
226
+ for epoch in range(num_epochs):
227
+ print(f"Epoch {epoch + 1} / {num_epochs}")
228
+ train_loss = 0
229
+ n_train = 0
230
+
231
+ self.model.eval()
232
+ for x, labels in data_loader:
233
+ train_batch_size = len(labels)
234
+ n_train += train_batch_size
235
+
236
+ x = x.to(device)
237
+ labels = labels.to(device)
238
+
239
+ t = self.sample_timesteps(x.size(0)).to(device)
240
+ xt, noise = self.noise_images(x, t)
241
+
242
+ optimizer.zero_grad()
243
+ predicted_noise = self.model(xt, t, labels, self.model.s)
244
+ loss = criterion(predicted_noise, noise)
245
+ loss.backward()
246
+ optimizer.step()
247
+
248
+ train_loss += loss.item() * train_batch_size
249
+
250
+ # 損失計算
251
+ avg_train_loss = train_loss / n_train
252
+
253
+ # 最小損失の更新とsの保存
254
+ if avg_train_loss < min_train_loss:
255
+ min_train_loss = avg_train_loss
256
+
257
+ # 結果表示
258
+ print(f"Epoch {epoch + 1}, Train loss: {avg_train_loss:.3f}")
259
+ progress_bar.progress((epoch + 1) / num_epochs, text=f"学習中... (train loss = {avg_train_loss:.3f})")
260
+
261
+ print("sの最終値: ", self.model.s)
models/unet.py CHANGED
@@ -174,14 +174,14 @@ class UNet_conditional(UNet):
174
  super().__init__(c_in, c_out, time_dim, **kwargs)
175
  self.label_emb = nn.Embedding(num_classes, time_dim)
176
 
177
- def forward(self, x, t, y):
178
- """
179
- ラベルの埋め込みをタイムステップの埋め込みに加算
180
- """
181
  t = t.unsqueeze(-1)
182
  t = self.pos_encoding(t, self.time_dim)
183
 
184
- if y is not None:
185
- t += self.label_emb(y)
 
 
186
 
187
  return self.unet_forwad(x, t)
 
174
  super().__init__(c_in, c_out, time_dim, **kwargs)
175
  self.label_emb = nn.Embedding(num_classes, time_dim)
176
 
177
    # Add the label embedding and the pseudo-word embedding to the
    # timestep embedding before running the UNet.
    def forward(self, x, t, labels=None, s=None):
        """
        Conditional forward pass.

        Args:
            x: Input image batch.
            t: Timestep indices; unsqueezed and positionally encoded here.
            labels: Optional class labels, embedded via ``self.label_emb``
                and added to the timestep embedding.
            s: Optional pseudo-word embedding (Textual Inversion) added
                directly to the timestep embedding.

        Returns:
            The UNet output for ``x`` under the combined conditioning.
        """
        t = t.unsqueeze(-1)
        t = self.pos_encoding(t, self.time_dim)

        if labels is not None:
            t += self.label_emb(labels)
        if s is not None:
            t += s

        return self.unet_forwad(x, t)
train.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from streamlit.delta_generator import DeltaGenerator
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from torchvision import datasets, transforms
5
+
6
+ from models.diffusion import Diffusion
7
+
8
+
9
def finetune(
    data_dir: str,
    model: Diffusion,
    criterion: torch.nn.Module,
    optimizer: type[torch.optim.Optimizer],
    num_epochs: int,
    learning_rate: float,
    num_augmentations: int,
    progress_bar: DeltaGenerator,
) -> Diffusion:
    """
    Fine-tune a pretrained diffusion model on user-drawn characters
    via Textual Inversion (``Diffusion.fit_s``).

    Args:
        data_dir: Root directory in ``ImageFolder`` layout (one
            sub-directory per character class).
        model: Pretrained ``Diffusion`` wrapper to fine-tune.
        criterion: Loss on predicted vs. true noise (e.g. MSE).
        optimizer: Optimizer class, e.g. ``torch.optim.AdamW``.
        num_epochs: Number of training epochs.
        learning_rate: Learning rate for the pseudo-word embedding.
        num_augmentations: Number of randomly-augmented copies of the
            dataset appended to the un-augmented one.
        progress_bar: Streamlit progress bar forwarded to training.

    Returns:
        The same ``Diffusion`` instance after fine-tuning.
    """
    # Deterministic preprocessing for the base dataset.
    transform = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ]
    )

    # Random affine jitter for augmentation; fill=255 keeps the
    # background white where the image is shifted/scaled.
    transform_aug = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.RandomAffine(
                degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05), fill=255
            ),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ]
    )

    # Scan the image folder once per transform instead of constructing a
    # fresh ImageFolder per augmentation copy (which re-reads the disk
    # each time). Transforms run lazily at access time, so repeating the
    # same augmented dataset still yields fresh random augmentations.
    base_dataset = datasets.ImageFolder(data_dir, transform=transform)
    aug_dataset = datasets.ImageFolder(data_dir, transform=transform_aug)
    sample_dataset = torch.utils.data.ConcatDataset(
        [base_dataset] + [aug_dataset] * num_augmentations
    )

    # Shuffled loader over original + augmented samples.
    data_loader = DataLoader(sample_dataset, batch_size=32, shuffle=True)

    # Learn the pseudo-word embedding on the drawn characters.
    model.fit_s(
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        data_loader=data_loader,
        device=model.device,
        progress_bar=progress_bar,
    )

    return model
utils.py CHANGED
@@ -1,13 +1,12 @@
1
  import os
2
- import tempfile
 
3
  from PIL import Image
4
 
5
 
6
- def initialize_data_dir() -> str:
7
- if "data_dir" not in os.environ:
8
- data_dir = tempfile.mkdtemp()
9
- os.environ["data_dir"] = data_dir
10
- return os.environ["data_dir"]
11
 
12
 
13
  def save_image(img, char, idx, data_dir):
 
1
  import os
2
+ from pathlib import Path
3
+
4
  from PIL import Image
5
 
6
 
7
def initialize_data_dir(data_dir: str) -> str:
    """Create *data_dir* (and any missing parents) if absent, then return it."""
    target = Path(data_dir)
    target.mkdir(parents=True, exist_ok=True)
    return data_dir
 
 
10
 
11
 
12
  def save_image(img, char, idx, data_dir):