Commit 6f21460 · 1 parent: 9586f4c

deploy mock

Files changed:
- .gitignore +3 -0
- README.md +1 -0
- app.py +138 -2
- models/__init__.py +0 -0
- models/diffusion.py +196 -0
- models/unet.py +187 -0
- requirements.txt +3 -0
- utils.py +39 -0
.gitignore
CHANGED
@@ -1,2 +1,5 @@
 .venv*
 __pycache__*
+
+# large files
+models/model.pth
README.md
CHANGED
@@ -5,6 +5,7 @@ colorFrom: red
 colorTo: purple
 sdk: streamlit
 sdk_version: 1.40.1
+python_version: 3.10.15
 app_file: app.py
 pinned: false
 ---
app.py
CHANGED
@@ -1,4 +1,140 @@
+import time
+import traceback
+
 import streamlit as st
 
-
-st.
+import torch
+from PIL import Image
+from streamlit_drawable_canvas import st_canvas
+
+from models.diffusion import Diffusion
+from utils import initialize_data_dir, save_image
+
+# Page configuration
+st.set_page_config(page_title="Handwritten Character Generation App", layout="wide")
+
+# Title
+st.title("Handwritten Character Generation App")
+
+# Description
+st.markdown(
+    """
+    In this app, draw each of the handwritten characters あ, い, う, え, and お five times.
+    Pressing the "Generate" button fine-tunes the model and displays the generated characters.
+    The learning rate, number of epochs, and optimizer can also be adjusted.
+    """
+)
+
+# Character list
+characters = ["あ", "い", "う", "え", "お"]
+num_samples = 5  # number of samples per character
+
+# Layout of the drawing areas
+st.header("Draw the handwritten characters")
+
+# Directory for saving the drawings
+data_dir = initialize_data_dir()
+
+# Create the drawing canvases
+for char in characters:
+    st.subheader(f"Draw the character '{char}' {num_samples} times")
+    cols = st.columns(num_samples)
+    for i in range(num_samples):
+        with cols[i]:
+            canvas = st_canvas(
+                fill_color="white",
+                stroke_width=3,
+                stroke_color="black",
+                background_color="white",
+                width=150,
+                height=150,
+                drawing_mode="freedraw",
+                key=f"{char}_{i}",
+            )
+            if canvas.image_data is not None:
+                img = Image.fromarray(
+                    canvas.image_data.astype("uint8"), "RGBA"
+                ).convert("L")
+                # Binarize
+                img = img.point(lambda x: 0 if x < 128 else 255, "1")
+                save_image(img, char, i, data_dir)
+
+# Hyperparameter inputs
+st.sidebar.header("Hyperparameter Settings")
+learning_rate = st.sidebar.number_input(
+    "Learning rate", min_value=0.0001, max_value=1.0, value=0.001, step=0.0001, format="%.4f"
+)
+epochs = st.sidebar.number_input(
+    "Epochs", min_value=1, max_value=100, value=10, step=1
+)
+optimizer_name = st.sidebar.selectbox("Optimizer", ["SGD", "Adam", "RMSprop"])
+
+
+# Sampling settings
+st.sidebar.header("Sampling Settings")
+noise_steps = st.sidebar.number_input(
+    "Noise steps", min_value=1, max_value=1000, value=1000, step=1
+)
+beta_start = st.sidebar.number_input(
+    "Initial beta", min_value=0.0, max_value=1.0, value=0.0001, step=0.0001, format="%.4f"
+)
+beta_end = st.sidebar.number_input(
+    "Final beta", min_value=0.0, max_value=1.0, value=0.02, step=0.0001, format="%.4f"
+)
+
+
+# Generate button
+if st.button("Generate"):
+    try:
+        # Device setup
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Instantiate the model (sampling settings come from the sidebar)
+        diffusion_model = Diffusion(
+            noise_steps=noise_steps,
+            beta_start=beta_start,
+            beta_end=beta_end,
+            img_size=32,
+            num_classes=46,
+            c_in=3,
+            c_out=3,
+            device=device,
+        )
+
+        save_path = "models/model.pth"
+
+        # Load the trained weights
+        diffusion_model.model.load_state_dict(
+            torch.load(save_path, weights_only=True, map_location=device)
+        )
+
+        # Generate and display the font images
+        def chunked(iterable, n):
+            for i in range(0, len(iterable), n):
+                yield iterable[i : i + n]
+
+        labels = list(range(46))
+        columns_per_row = 5
+        start_time = time.time()
+        with st.spinner("Generating font images..."):
+            labels_tensor = torch.tensor(labels).to(device)
+            font_image = diffusion_model.sample(
+                diffusion_model.model, labels_tensor
+            )
+        elapsed_time = time.time() - start_time
+        st.success(f"Font images generated successfully ({elapsed_time:.2f} s)")
+        for label_row in chunked(labels, columns_per_row):
+            cols = st.columns(columns_per_row)
+            for col, label in zip(cols, label_row):
+                col.image(
+                    font_image[label].permute(1, 2, 0).cpu().numpy(),
+                    caption=f"{label}",
+                    use_container_width=True,
+                )
+
+    except FileNotFoundError as e:
+        st.error(str(e))
+    except ValueError as e:
+        st.error(str(e))
+    except Exception as e:
+        st.error(f"An unexpected error occurred: {e}")
+        st.error(traceback.format_exc())
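Note: a minimal standalone sketch of the generation path that app.py exercises, runnable outside Streamlit. It assumes the checkpoint models/model.pth exists and matches the Diffusion defaults used above; it is an illustration, not part of this commit.

import torch
from models.diffusion import Diffusion

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diffusion = Diffusion(img_size=32, num_classes=46, device=device)
# Load the weights the Space ships with (same path as in app.py).
diffusion.model.load_state_dict(
    torch.load("models/model.pth", weights_only=True, map_location=device)
)
labels = torch.tensor([0, 1, 2]).to(device)  # any class ids in [0, 46)
images = diffusion.sample(diffusion.model, labels)  # uint8 tensor, (3, 3, 32, 32)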
models/__init__.py
ADDED
(empty file)
models/diffusion.py
ADDED
@@ -0,0 +1,196 @@
+import numpy as np
+import torch
+from .unet import UNet_conditional
+
+
+class Diffusion:
+    def __init__(
+        self,
+        noise_steps: int = 1000,
+        beta_start: float = 1e-4,
+        beta_end: float = 0.02,
+        img_size: int = 32,
+        num_classes: int = 46,
+        c_in: int = 3,
+        c_out: int = 3,
+        device: torch.device = torch.device("cuda"),
+        time_dim: int = 256,
+        **kwargs,
+    ):
+
+        self.noise_steps = noise_steps
+        self.beta_start = beta_start
+        self.beta_end = beta_end
+        self.img_size = img_size
+        self.device = device
+        self.time_dim = time_dim
+        self.num_classes = num_classes
+        self.c_in = c_in
+        self.c_out = c_out
+        self.model = UNet_conditional(
+            c_in, c_out, time_dim, num_classes=num_classes, **kwargs
+        ).to(device)
+
+        self.beta = self.prepare_noise_schedule().to(device)
+        self.alpha = 1.0 - self.beta
+        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
+
+    def __call__(self, x, t, labels):
+        return self.model(x, t, labels)
+
+    def prepare_noise_schedule(self) -> torch.Tensor:
+        """
+        Build the linear noise schedule.
+        """
+        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)
+
+    def noise_images(self, x: torch.Tensor, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Produce the noised images and the corresponding noise.
+        """
+        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
+        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
+        noise = torch.randn_like(x)
+        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise
+
+    def sample_timesteps(self, n: int) -> torch.Tensor:
+        """
+        Sample random timesteps.
+        """
+        return torch.randint(low=1, high=self.noise_steps, size=(n,))
+
+    def sample(self, model: torch.nn.Module, labels: torch.Tensor) -> torch.Tensor:
+        """
+        Generate images for the given class labels.
+        """
+        self.model = model
+        n = len(labels)
+        print(f"Sampling {n} new images....")
+        model.eval()
+
+        with torch.no_grad():
+            x = torch.randn((n, self.c_in, self.img_size, self.img_size)).to(
+                self.device
+            )
+            for i in reversed(range(1, self.noise_steps)):
+                t = (torch.ones(n) * i).long().to(self.device)
+                predicted_noise = model(x, t, labels)
+                alpha = self.alpha[t][:, None, None, None]
+                alpha_hat = self.alpha_hat[t][:, None, None, None]
+                beta = self.beta[t][:, None, None, None]
+                if i > 1:
+                    noise = torch.randn_like(x)
+                else:
+                    noise = torch.zeros_like(x)
+                x = (
+                    1
+                    / torch.sqrt(alpha)
+                    * (
+                        x
+                        - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise
+                    )
+                    + torch.sqrt(beta) * noise
+                )
+        model.train()
+        x = (x.clamp(-1, 1) + 1) / 2
+        x = (x * 255).type(torch.uint8)
+        return x
+
+    def fit(
+        self,
+        optimizer: torch.optim.Optimizer,
+        criterion: torch.nn.Module,
+        num_epochs: int,
+        train_loader: torch.utils.data.DataLoader,
+        test_loader: torch.utils.data.DataLoader,
+        model: torch.nn.Module,
+        device: torch.device,
+        history: np.ndarray,
+        save_path: str,
+    ) -> np.ndarray:
+        """
+        Train the model.
+        """
+        base_epochs = len(history)
+        # Initialize the minimum loss
+        min_test_loss = 9e9
+
+        for epoch in range(base_epochs, base_epochs + num_epochs):
+            # Accumulated loss per epoch (before averaging)
+            train_loss, test_loss = 0, 0
+            # Accumulated sample counts per epoch
+            n_train, n_test = 0, 0
+
+            # Training phase
+            self.model.train()
+            for x, labels in train_loader:
+                # Number of samples in this batch
+                train_batch_size = len(labels)
+                # Accumulated sample count for the epoch
+                n_train += train_batch_size
+
+                # Move to the device
+                x = x.to(device)
+                labels = labels.to(device)
+
+                # Sample noise timesteps
+                t = self.sample_timesteps(x.size(0)).to(device)
+                # Produce the noised images and the noise
+                xt, noise = self.noise_images(x, t)
+
+                # Reset the gradients
+                optimizer.zero_grad()
+                # Predict the noise
+                predicted_noise = model(xt, t, labels)
+                # Compute the loss
+                loss = criterion(predicted_noise, noise)
+                # Backpropagate
+                loss.backward()
+                # Update the parameters
+                optimizer.step()
+
+                # Accumulate the (pre-averaging) loss
+                train_loss += loss.item() * train_batch_size
+
+            # Evaluation phase
+            self.model.eval()
+            for x, labels in test_loader:
+                # Number of samples in this batch
+                test_batch_size = len(labels)
+                # Accumulated sample count for the epoch
+                n_test += test_batch_size
+
+                # Move to the device
+                x = x.to(device)
+                labels = labels.to(device)
+
+                # Sample noise timesteps
+                t = self.sample_timesteps(x.size(0)).to(device)
+                # Produce the noised images and the noise
+                xt, noise = self.noise_images(x, t)
+                # Predict the noise
+                predicted_noise = model(xt, t, labels)
+                # Compute the loss
+                loss = criterion(predicted_noise, noise)
+
+                # Accumulate the (pre-averaging) loss
+                test_loss += loss.item() * test_batch_size
+
+            # Average the losses
+            avg_train_loss = train_loss / n_train
+            avg_test_loss = test_loss / n_test
+
+            # Update the minimum loss and save the best model
+            if avg_test_loss < min_test_loss:
+                min_test_loss = avg_test_loss
+                torch.save(self.model.state_dict(), save_path)
+
+            # Report the results
+            print(
+                f"Epoch {epoch + 1}, Train loss: {avg_train_loss:.3f}, Test loss: {avg_test_loss:.3f}"
+            )
+            # Record the history
+            item = np.array([epoch + 1, avg_train_loss, avg_test_loss])
+            history = np.vstack([history, item])
+
+        return history
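Note: the sampling loop above is the standard DDPM reverse update, x_{t-1} = (1/sqrt(alpha_t)) * (x_t - ((1 - alpha_t)/sqrt(1 - alpha_hat_t)) * predicted_noise) + sqrt(beta_t) * z. Below is a hedged sketch of how fit() might be driven; the dummy tensors and the Adam/MSELoss choices are assumptions for illustration, not part of this commit.

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from models.diffusion import Diffusion

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diffusion = Diffusion(img_size=32, num_classes=46, device=device)

# Dummy data: 8 RGB 32x32 images with class labels in [0, 46).
data = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 46, (8,)))
loader = DataLoader(data, batch_size=4)

history = diffusion.fit(
    optimizer=torch.optim.Adam(diffusion.model.parameters(), lr=1e-3),
    criterion=torch.nn.MSELoss(),
    num_epochs=1,
    train_loader=loader,
    test_loader=loader,
    model=diffusion.model,
    device=device,
    history=np.empty((0, 3)),  # columns: epoch, train loss, test loss
    save_path="models/model.pth",
)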
models/unet.py
ADDED
@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def one_param(m: nn.Module) -> nn.Parameter:
+    """
+    Get the model's first parameter.
+    """
+    return next(iter(m.parameters()))
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, channels):
+        super(SelfAttention, self).__init__()
+        self.channels = channels
+        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
+        self.ln = nn.LayerNorm([channels])
+        self.ff_self = nn.Sequential(
+            nn.LayerNorm([channels]),
+            nn.Linear(channels, channels),
+            nn.GELU(),
+            nn.Linear(channels, channels),
+        )
+
+    def forward(self, x):
+        size = x.shape[-1]
+        x = x.view(-1, self.channels, size * size).swapaxes(1, 2)
+        x_ln = self.ln(x)
+        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
+        attention_value = attention_value + x
+        attention_value = self.ff_self(attention_value) + attention_value
+        return attention_value.swapaxes(2, 1).view(-1, self.channels, size, size)
+
+
+class DoubleConv(nn.Module):
+    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
+        super().__init__()
+        self.residual = residual
+        if not mid_channels:
+            mid_channels = out_channels
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+            nn.GroupNorm(1, mid_channels),
+            nn.GELU(),
+            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.GroupNorm(1, out_channels),
+        )
+
+    def forward(self, x):
+        if self.residual:
+            return F.gelu(x + self.double_conv(x))
+        else:
+            return self.double_conv(x)
+
+
+class Down(nn.Module):
+    def __init__(self, in_channels, out_channels, emb_dim=256):
+        super().__init__()
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, in_channels, residual=True),
+            DoubleConv(in_channels, out_channels),
+        )
+
+        self.emb_layer = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(emb_dim, out_channels),
+        )
+
+    def forward(self, x, t):
+        x = self.maxpool_conv(x)
+        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
+        return x + emb
+
+
+class Up(nn.Module):
+    def __init__(self, in_channels, out_channels, emb_dim=256):
+        super().__init__()
+
+        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+        self.conv = nn.Sequential(
+            DoubleConv(in_channels, in_channels, residual=True),
+            DoubleConv(in_channels, out_channels, in_channels // 2),
+        )
+
+        self.emb_layer = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(emb_dim, out_channels),
+        )
+
+    def forward(self, x, skip_x, t):
+        x = self.up(x)
+        x = torch.cat([skip_x, x], dim=1)
+        x = self.conv(x)
+        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
+        return x + emb
+
+
+# U-Net definition
+class UNet(nn.Module):
+    def __init__(self, c_in=3, c_out=3, time_dim=256, remove_deep_conv=False):
+        super().__init__()
+        self.time_dim = time_dim
+        self.remove_deep_conv = remove_deep_conv
+        self.inc = DoubleConv(c_in, 64)
+        self.down1 = Down(64, 128)
+        self.sa1 = SelfAttention(128)
+        self.down2 = Down(128, 256)
+        self.sa2 = SelfAttention(256)
+        self.down3 = Down(256, 256)
+        self.sa3 = SelfAttention(256)
+
+        if remove_deep_conv:
+            self.bot1 = DoubleConv(256, 256)
+            self.bot3 = DoubleConv(256, 256)
+        else:
+            self.bot1 = DoubleConv(256, 512)
+            self.bot2 = DoubleConv(512, 512)
+            self.bot3 = DoubleConv(512, 256)
+
+        self.up1 = Up(512, 128)
+        self.sa4 = SelfAttention(128)
+        self.up2 = Up(256, 64)
+        self.sa5 = SelfAttention(64)
+        self.up3 = Up(128, 64)
+        self.sa6 = SelfAttention(64)
+        self.outc = nn.Conv2d(64, c_out, kernel_size=1)
+
+    def pos_encoding(self, t, channels):
+        inv_freq = 1.0 / (
+            10000
+            ** (
+                torch.arange(0, channels, 2, device=one_param(self).device).float()
+                / channels
+            )
+        )
+        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
+        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
+        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
+        return pos_enc
+
+    def unet_forward(self, x, t):
+        x1 = self.inc(x)
+        x2 = self.down1(x1, t)
+        x2 = self.sa1(x2)
+        x3 = self.down2(x2, t)
+        x3 = self.sa2(x3)
+        x4 = self.down3(x3, t)
+        x4 = self.sa3(x4)
+
+        x4 = self.bot1(x4)
+        if not self.remove_deep_conv:
+            x4 = self.bot2(x4)
+        x4 = self.bot3(x4)
+
+        x = self.up1(x4, x3, t)
+        x = self.sa4(x)
+        x = self.up2(x, x2, t)
+        x = self.sa5(x)
+        x = self.up3(x, x1, t)
+        x = self.sa6(x)
+        output = self.outc(x)
+        return output
+
+    def forward(self, x, t):
+        t = t.unsqueeze(-1)
+        t = self.pos_encoding(t, self.time_dim)
+        return self.unet_forward(x, t)
+
+
+class UNet_conditional(UNet):
+    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=46, **kwargs):
+        super().__init__(c_in, c_out, time_dim, **kwargs)
+        self.label_emb = nn.Embedding(num_classes, time_dim)
+
+    def forward(self, x, t, y):
+        """
+        Add the label embedding to the timestep embedding.
+        """
+        t = t.unsqueeze(-1)
+        t = self.pos_encoding(t, self.time_dim)
+
+        if y is not None:
+            t += self.label_emb(y)
+
+        return self.unet_forward(x, t)
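Note: a quick shape sanity check for the conditional U-Net, using the 32x32 RGB configuration the app assumes; the batch size and random inputs are arbitrary illustration values.

import torch
from models.unet import UNet_conditional

net = UNet_conditional(c_in=3, c_out=3, time_dim=256, num_classes=46)
x = torch.randn(2, 3, 32, 32)          # batch of two 32x32 RGB images
t = torch.randint(1, 1000, (2,))       # one timestep per image
y = torch.randint(0, 46, (2,))         # one class label per image
print(net(x, t, y).shape)              # torch.Size([2, 3, 32, 32])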
requirements.txt
CHANGED
@@ -1 +1,4 @@
 streamlit==1.40.1
+streamlit-drawable-canvas==0.9.3
+torch==2.5.1
+torchvision==0.20.1
utils.py
ADDED
@@ -0,0 +1,39 @@
+import os
+import tempfile
+from PIL import Image
+
+
+def initialize_data_dir() -> str:
+    if "data_dir" not in os.environ:
+        data_dir = tempfile.mkdtemp()
+        os.environ["data_dir"] = data_dir
+    return os.environ["data_dir"]
+
+
+def save_image(img, char, idx, data_dir):
+    char_dir = os.path.join(data_dir, char)
+    os.makedirs(char_dir, exist_ok=True)
+    img_path = os.path.join(char_dir, f"{idx}.png")
+    img.save(img_path)
+
+
+def load_images(characters, num_samples, data_dir, transform):
+    X = []
+    y = []
+    for label, char in enumerate(characters):
+        char_dir = os.path.join(data_dir, char)
+        if not os.path.exists(char_dir):
+            raise FileNotFoundError(
+                f"Images for character '{char}' are missing. Please draw all of the samples."
+            )
+        for i in range(num_samples):
+            img_path = os.path.join(char_dir, f"{i}.png")
+            if not os.path.exists(img_path):
+                raise FileNotFoundError(
+                    f"Sample {i + 1} for character '{char}' does not exist."
+                )
+            img = Image.open(img_path).convert("L")
+            img = transform(img)
+            X.append(img)
+            y.append(label)
+    return X, y
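Note: a hedged usage sketch for load_images. The torchvision transform is an assumption (app.py saves binarized 1-bit PNGs, so a resize plus ToTensor is one plausible preprocessing); torchvision itself is pinned in requirements.txt.

from torchvision import transforms
from utils import initialize_data_dir, load_images

data_dir = initialize_data_dir()
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
# Raises FileNotFoundError until all 5 x 5 samples have been drawn and saved.
X, y = load_images(["あ", "い", "う", "え", "お"], 5, data_dir, transform)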