Spaces:
Sleeping
Sleeping
Commit
·
4d9b586
1
Parent(s):
e2044cd
Textual Inversion
Browse files- app.py +44 -19
- models/diffusion.py +66 -1
- models/unet.py +6 -6
- train.py +61 -0
- utils.py +5 -6
app.py
CHANGED
@@ -7,10 +7,13 @@ from PIL import Image
|
|
7 |
from streamlit_drawable_canvas import st_canvas
|
8 |
|
9 |
from models.diffusion import Diffusion
|
|
|
10 |
from utils import initialize_data_dir, save_image
|
11 |
|
12 |
# 設定
|
13 |
-
st.set_page_config(
|
|
|
|
|
14 |
|
15 |
# タイトル
|
16 |
st.title("手書き文字生成アプリ")
|
@@ -32,7 +35,7 @@ num_samples = 5 # 各文字のサンプル数
|
|
32 |
st.header("手書き文字を描いてください")
|
33 |
|
34 |
# 保存用ディレクトリ
|
35 |
-
data_dir = initialize_data_dir()
|
36 |
|
37 |
# 描画領域の作成
|
38 |
for char in characters:
|
@@ -53,33 +56,42 @@ for char in characters:
|
|
53 |
if canvas.image_data is not None:
|
54 |
img = Image.fromarray(
|
55 |
canvas.image_data.astype("uint8"), "RGBA"
|
56 |
-
).convert("
|
57 |
-
#
|
58 |
-
img =
|
59 |
save_image(img, char, i, data_dir)
|
60 |
|
61 |
# ハイパーパラメータの入力
|
62 |
-
st.sidebar.header("
|
63 |
learning_rate = st.sidebar.number_input(
|
64 |
"学習率", min_value=0.0001, max_value=1.0, value=0.001, step=0.0001, format="%.4f"
|
65 |
)
|
66 |
epochs = st.sidebar.number_input(
|
67 |
"エポック数", min_value=1, max_value=100, value=10, step=1
|
68 |
)
|
69 |
-
optimizer_name = st.sidebar.selectbox("最適化手法", ["SGD", "Adam", "RMSprop"])
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
# サンプリングの設定
|
73 |
st.sidebar.header("サンプリング設定")
|
74 |
noise_steps = st.sidebar.number_input(
|
75 |
-
"ノイズステップ数", min_value=
|
76 |
-
)
|
77 |
-
beta_start = st.sidebar.number_input(
|
78 |
-
"βの初期値", min_value=0.0, max_value=1.0, value=0.0001, step=0.0001, format="%.4f"
|
79 |
)
|
80 |
-
|
81 |
-
"
|
82 |
)
|
|
|
|
|
83 |
|
84 |
|
85 |
# 生成ボタン
|
@@ -107,26 +119,39 @@ if st.button("生成"):
|
|
107 |
torch.load(save_path, weights_only=True, map_location=device)
|
108 |
)
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
# フォント画像の生成と表示
|
111 |
def chuncked(iterable, n):
|
112 |
for i in range(0, len(iterable), n):
|
113 |
yield iterable[i : i + n]
|
114 |
|
115 |
-
labels = list(range(
|
116 |
columns_per_row = 5
|
117 |
start_time = time.time()
|
118 |
with st.spinner("フォント画像を生成中..."):
|
119 |
labels_tensor = torch.tensor(labels).to(device)
|
120 |
-
font_image = diffusion_model.sample(
|
121 |
-
diffusion_model.model, labels_tensor
|
122 |
-
)
|
123 |
elapsed_time = time.time() - start_time
|
124 |
st.success(f"フォント画像の生成に成功しました({elapsed_time:.2f}秒)")
|
125 |
for label_row in chuncked(labels, columns_per_row):
|
126 |
cols = st.columns(columns_per_row)
|
127 |
for col, label in zip(cols, label_row):
|
|
|
128 |
col.image(
|
129 |
-
font_image[label].permute(1, 2, 0).cpu().numpy(),
|
130 |
caption=f"{label}",
|
131 |
use_container_width=True,
|
132 |
)
|
|
|
7 |
from streamlit_drawable_canvas import st_canvas
|
8 |
|
9 |
from models.diffusion import Diffusion
|
10 |
+
from train import finetune
|
11 |
from utils import initialize_data_dir, save_image
|
12 |
|
13 |
# 設定
|
14 |
+
st.set_page_config(
|
15 |
+
page_title="手書き文字生成アプリ", layout="wide", page_icon=":pencil:"
|
16 |
+
)
|
17 |
|
18 |
# タイトル
|
19 |
st.title("手書き文字生成アプリ")
|
|
|
35 |
st.header("手書き文字を描いてください")
|
36 |
|
37 |
# 保存用ディレクトリ
|
38 |
+
data_dir = initialize_data_dir("./sample_images")
|
39 |
|
40 |
# 描画領域の作成
|
41 |
for char in characters:
|
|
|
56 |
if canvas.image_data is not None:
|
57 |
img = Image.fromarray(
|
58 |
canvas.image_data.astype("uint8"), "RGBA"
|
59 |
+
).convert("RGB")
|
60 |
+
# 輝度を逆転
|
61 |
+
img = Image.eval(img, lambda x: 255 - x)
|
62 |
save_image(img, char, i, data_dir)
|
63 |
|
64 |
# ハイパーパラメータの入力
|
65 |
+
st.sidebar.header("学習パラメータ設定")
|
66 |
learning_rate = st.sidebar.number_input(
|
67 |
"学習率", min_value=0.0001, max_value=1.0, value=0.001, step=0.0001, format="%.4f"
|
68 |
)
|
69 |
epochs = st.sidebar.number_input(
|
70 |
"エポック数", min_value=1, max_value=100, value=10, step=1
|
71 |
)
|
72 |
+
optimizer_name = st.sidebar.selectbox("最適化手法", ["SGD", "Adam", "RMSprop", "AdamW"])
|
73 |
+
if optimizer_name == "SGD":
|
74 |
+
optimizer = torch.optim.SGD
|
75 |
+
elif optimizer_name == "Adam":
|
76 |
+
optimizer = torch.optim.Adam
|
77 |
+
elif optimizer_name == "RMSprop":
|
78 |
+
optimizer = torch.optim.RMSprop
|
79 |
+
elif optimizer_name == "AdamW":
|
80 |
+
optimizer = torch.optim.AdamW
|
81 |
+
num_augmentations = st.sidebar.number_input(
|
82 |
+
"データ拡張回数", min_value=0, max_value=100, value=20, step=1
|
83 |
+
)
|
84 |
|
85 |
# サンプリングの設定
|
86 |
st.sidebar.header("サンプリング設定")
|
87 |
noise_steps = st.sidebar.number_input(
|
88 |
+
"ノイズステップ数", min_value=2, max_value=1000, value=1000, step=1
|
|
|
|
|
|
|
89 |
)
|
90 |
+
num_chars = st.sidebar.number_input(
|
91 |
+
"生成文字数", min_value=1, max_value=46, value=5, step=1
|
92 |
)
|
93 |
+
beta_start = 0.0001
|
94 |
+
beta_end = 0.02
|
95 |
|
96 |
|
97 |
# 生成ボタン
|
|
|
119 |
torch.load(save_path, weights_only=True, map_location=device)
|
120 |
)
|
121 |
|
122 |
+
# Fine-tuning: learn the pseudo-word embedding from the saved drawings.
progress_bar = st.progress(0, text="学習中...")
diffusion_model = finetune(
    data_dir,
    diffusion_model,
    criterion=torch.nn.MSELoss(),
    # Pass the optimizer class selected in the sidebar. The call used to
    # hard-code torch.optim.AdamW, so the selectbox had no effect.
    optimizer=optimizer,
    num_epochs=epochs,
    learning_rate=learning_rate,
    num_augmentations=num_augmentations,
    progress_bar=progress_bar,
)
# Remove the progress bar once training has finished.
progress_bar.empty()
|
135 |
+
|
136 |
# フォント画像の生成と表示
|
137 |
def chuncked(iterable, n):
    """Yield consecutive slices of ``iterable`` with at most ``n`` items each.

    ``iterable`` must support ``len()`` and slicing; the last slice may be
    shorter than ``n``.
    """
    starts = range(0, len(iterable), n)
    yield from (iterable[start : start + n] for start in starts)
|
140 |
|
141 |
+
labels = list(range(num_chars))
|
142 |
columns_per_row = 5
|
143 |
start_time = time.time()
|
144 |
with st.spinner("フォント画像を生成中..."):
|
145 |
labels_tensor = torch.tensor(labels).to(device)
|
146 |
+
font_image = diffusion_model.sample(diffusion_model.model, labels_tensor)
|
|
|
|
|
147 |
elapsed_time = time.time() - start_time
|
148 |
st.success(f"フォント画像の生成に成功しました({elapsed_time:.2f}秒)")
|
149 |
for label_row in chuncked(labels, columns_per_row):
|
150 |
cols = st.columns(columns_per_row)
|
151 |
for col, label in zip(cols, label_row):
|
152 |
+
# 輝度を逆転させて表示
|
153 |
col.image(
|
154 |
+
255 - font_image[label].permute(1, 2, 0).cpu().numpy(),
|
155 |
caption=f"{label}",
|
156 |
use_container_width=True,
|
157 |
)
|
models/diffusion.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import numpy as np
|
2 |
import torch
|
|
|
|
|
3 |
from .unet import UNet_conditional
|
4 |
|
5 |
|
@@ -49,7 +51,9 @@ class Diffusion:
|
|
49 |
ノイズ画像とノイズの生成
|
50 |
"""
|
51 |
sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
|
52 |
-
sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[
|
|
|
|
|
53 |
noise = torch.randn_like(x)
|
54 |
return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise
|
55 |
|
@@ -194,3 +198,64 @@ class Diffusion:
|
|
194 |
history = np.vstack([history, item])
|
195 |
|
196 |
return history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
+
from streamlit.delta_generator import DeltaGenerator
|
4 |
+
|
5 |
from .unet import UNet_conditional
|
6 |
|
7 |
|
|
|
51 |
ノイズ画像とノイズの生成
|
52 |
"""
|
53 |
sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
|
54 |
+
sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[
|
55 |
+
:, None, None, None
|
56 |
+
]
|
57 |
noise = torch.randn_like(x)
|
58 |
return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise
|
59 |
|
|
|
198 |
history = np.vstack([history, item])
|
199 |
|
200 |
return history
|
201 |
+
|
202 |
+
def fit_s(
    self,
    criterion: torch.nn.Module,
    optimizer: "type[torch.optim.Optimizer]",
    num_epochs: int,
    learning_rate: float,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    progress_bar: "DeltaGenerator",
) -> None:
    """
    Train the pseudo-word embedding ``s`` (Textual Inversion).

    All existing parameters of ``self.model`` are frozen, a fresh
    ``self.model.s`` parameter of shape ``(1, self.time_dim)`` is created,
    and only ``s`` is optimised so that the model's noise prediction matches
    the noise added by ``self.noise_images``.

    Args:
        criterion: Loss applied to ``(predicted_noise, noise)``.
        optimizer: Optimizer *class* (not an instance); it is instantiated
            here as ``optimizer([self.model.s], lr=learning_rate)``.
        num_epochs: Number of passes over ``data_loader``.
        learning_rate: Learning rate handed to the optimizer.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device that the batches and ``s`` are placed on.
        progress_bar: Object whose ``progress(value, text=...)`` method is
            called once per epoch with the completed fraction.

    Raises:
        ValueError: If an epoch yields no samples from ``data_loader``.
    """
    # Smallest average epoch loss seen so far.
    # NOTE(review): the original comment mentions saving s at the best
    # epoch, but nothing is persisted here - confirm whether a checkpoint
    # of s is wanted.
    min_train_loss = 9e9

    # Freeze the network; s is the only trainable parameter.
    for params in self.model.parameters():
        params.requires_grad = False
    self.model.s = torch.nn.Parameter(
        1e-2 * torch.randn(1, self.time_dim, device=device)
    )
    print("sの初期値: ", self.model.s)

    # `optimizer` is a class; keep the instance under its own name instead
    # of rebinding the parameter.
    s_optimizer = optimizer([self.model.s], lr=learning_rate)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1} / {num_epochs}")
        train_loss = 0
        n_train = 0

        # The model is kept in eval mode while s is optimised.
        self.model.eval()
        for x, labels in data_loader:
            train_batch_size = len(labels)
            n_train += train_batch_size

            x = x.to(device)
            labels = labels.to(device)

            # Noise each image at a randomly sampled timestep.
            t = self.sample_timesteps(x.size(0)).to(device)
            xt, noise = self.noise_images(x, t)

            s_optimizer.zero_grad()
            predicted_noise = self.model(xt, t, labels, self.model.s)
            loss = criterion(predicted_noise, noise)
            loss.backward()
            s_optimizer.step()

            # Weight by batch size so the epoch figure is a per-sample mean.
            train_loss += loss.item() * train_batch_size

        # An empty loader used to surface as a bare ZeroDivisionError below.
        if n_train == 0:
            raise ValueError(
                "data_loader yielded no samples; cannot compute the training loss"
            )

        # Average loss over the epoch.
        avg_train_loss = train_loss / n_train

        # Track the best (lowest) epoch loss.
        min_train_loss = min(min_train_loss, avg_train_loss)

        # Report progress.
        print(f"Epoch {epoch + 1}, Train loss: {avg_train_loss:.3f}")
        progress_bar.progress(
            (epoch + 1) / num_epochs,
            text=f"学習中... (train loss = {avg_train_loss:.3f})",
        )

    print("sの最終値: ", self.model.s)
|
models/unet.py
CHANGED
@@ -174,14 +174,14 @@ class UNet_conditional(UNet):
|
|
174 |
super().__init__(c_in, c_out, time_dim, **kwargs)
|
175 |
self.label_emb = nn.Embedding(num_classes, time_dim)
|
176 |
|
177 |
-
|
178 |
-
|
179 |
-
ラベルの埋め込みをタイムステップの埋め込みに加算
|
180 |
-
"""
|
181 |
t = t.unsqueeze(-1)
|
182 |
t = self.pos_encoding(t, self.time_dim)
|
183 |
|
184 |
-
if
|
185 |
-
t += self.label_emb(
|
|
|
|
|
186 |
|
187 |
return self.unet_forwad(x, t)
|
|
|
174 |
super().__init__(c_in, c_out, time_dim, **kwargs)
|
175 |
self.label_emb = nn.Embedding(num_classes, time_dim)
|
176 |
|
177 |
+
# Add the label embedding and the pseudo-word to the timestep embedding.
def forward(self, x, t, labels=None, s=None):
    """Run the U-Net on ``x`` with a conditioned timestep embedding.

    ``t`` is encoded with ``self.pos_encoding``; when ``labels`` is given,
    ``self.label_emb(labels)`` is added to that encoding, and when ``s`` is
    given it is added as well. The result is passed with ``x`` to
    ``self.unet_forwad``.
    """
    step_emb = self.pos_encoding(t.unsqueeze(-1), self.time_dim)

    # Both conditioning terms are optional and are accumulated in place.
    if labels is not None:
        step_emb.add_(self.label_emb(labels))
    if s is not None:
        step_emb.add_(s)

    return self.unet_forwad(x, step_emb)
|
train.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from streamlit.delta_generator import DeltaGenerator
|
2 |
+
import torch
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
from torchvision import datasets, transforms
|
5 |
+
|
6 |
+
from models.diffusion import Diffusion
|
7 |
+
|
8 |
+
|
9 |
+
def finetune(
    data_dir: str,
    model: Diffusion,
    criterion: torch.nn.Module,
    optimizer: "type[torch.optim.Optimizer]",
    num_epochs: int,
    learning_rate: float,
    num_augmentations: int,
    progress_bar: DeltaGenerator,
) -> Diffusion:
    """Fine-tune ``model`` on the images stored under ``data_dir``.

    The images are loaded with ``datasets.ImageFolder``, converted to
    3-channel grayscale, resized to 32x32 and normalised with
    ``Normalize(0.5, 0.5)``. The training set is one un-augmented copy of the
    folder plus ``num_augmentations`` randomly translated/scaled copies, and
    it is passed to ``model.fit_s`` in shuffled batches of 32.

    Args:
        data_dir: Root directory readable by ``datasets.ImageFolder``.
        model: Diffusion wrapper; its ``fit_s`` method and ``device``
            attribute are used.
        criterion: Loss module forwarded to ``model.fit_s``.
        optimizer: Optimizer *class* (not an instance) forwarded to
            ``model.fit_s``.
        num_epochs: Number of training epochs.
        learning_rate: Learning rate forwarded to ``model.fit_s``.
        num_augmentations: Number of augmented copies of the folder to add.
        progress_bar: Progress widget forwarded to ``model.fit_s``.

    Returns:
        The same ``model`` object, after ``fit_s`` has run.
    """
    # Plain preprocessing for the un-augmented copy.
    transform = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ]
    )

    # Same preprocessing plus a small random shift / rescale.
    # NOTE(review): fill=255 pads exposed borders with white, while app.py
    # saves luminance-inverted drawings - confirm the fill matches the
    # background of the stored images.
    transform_aug = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.RandomAffine(
                degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05), fill=255
            ),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ]
    )

    # Build one flat ConcatDataset. Repeated `dataset += ...` nested a new
    # ConcatDataset around the previous one on every iteration and rescanned
    # the directory each time. The random transform is drawn on every item
    # access, so one augmented ImageFolder can be listed several times.
    parts = [datasets.ImageFolder(data_dir, transform=transform)]
    if num_augmentations > 0:
        augmented = datasets.ImageFolder(data_dir, transform=transform_aug)
        parts.extend([augmented] * num_augmentations)
    sample_dataset = torch.utils.data.ConcatDataset(parts)

    # Shuffled mini-batches for training.
    data_loader = DataLoader(sample_dataset, batch_size=32, shuffle=True)

    # Fine-tune the model on the assembled data.
    model.fit_s(
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        data_loader=data_loader,
        device=model.device,
        progress_bar=progress_bar,
    )

    return model
|
utils.py
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
import os
|
2 |
-
import
|
|
|
3 |
from PIL import Image
|
4 |
|
5 |
|
6 |
-
def initialize_data_dir() -> str:
|
7 |
-
|
8 |
-
|
9 |
-
os.environ["data_dir"] = data_dir
|
10 |
-
return os.environ["data_dir"]
|
11 |
|
12 |
|
13 |
def save_image(img, char, idx, data_dir):
|
|
|
1 |
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
from PIL import Image
|
5 |
|
6 |
|
7 |
+
def initialize_data_dir(data_dir: str) -> str:
    """Ensure ``data_dir`` exists and return it unchanged.

    Missing parent directories are created as well; an already existing
    directory is not an error.
    """
    target = Path(data_dir)
    target.mkdir(parents=True, exist_ok=True)
    return data_dir
|
|
|
|
|
10 |
|
11 |
|
12 |
def save_image(img, char, idx, data_dir):
|