matsudatkm committed on
Commit
6f21460
·
1 Parent(s): 9586f4c

deploy mock

Files changed (8)
  1. .gitignore +3 -0
  2. README.md +1 -0
  3. app.py +138 -2
  4. models/__init__.py +0 -0
  5. models/diffusion.py +196 -0
  6. models/unet.py +187 -0
  7. requirements.txt +3 -0
  8. utils.py +39 -0
.gitignore CHANGED
@@ -1,2 +1,5 @@
 .venv*
 __pycache__*
+
+# large files
+models/model.pth
README.md CHANGED
@@ -5,6 +5,7 @@ colorFrom: red
 colorTo: purple
 sdk: streamlit
 sdk_version: 1.40.1
+python_version: 3.10.15
 app_file: app.py
 pinned: false
 ---
app.py CHANGED
@@ -1,4 +1,140 @@
+import time
+import traceback
+
 import streamlit as st
-
-st.title("This is a title")
-st.write("This is a paragraph")
+import torch
+from PIL import Image
+from streamlit_drawable_canvas import st_canvas
+
+from models.diffusion import Diffusion
+from utils import initialize_data_dir, save_image
+
+# Page configuration
+st.set_page_config(page_title="Handwritten Character Generator", layout="wide")
+
+# Title
+st.title("Handwritten Character Generator")
+
+# Description
+st.markdown(
+    """
+    In this app, draw each of the handwritten characters "あ", "い", "う", "え", and "お" five times.
+    Pressing the "Generate" button fine-tunes the model and displays the generated characters.
+    The learning rate, number of epochs, and optimizer can also be adjusted.
+    """
+)
+
+# Character list
+characters = ["あ", "い", "う", "え", "お"]
+num_samples = 5  # number of samples per character
+
+# Drawing area layout
+st.header("Draw the handwritten characters")
+
+# Directory for saved images
+data_dir = initialize_data_dir()
+
+# Create the drawing canvases
+for char in characters:
+    st.subheader(f'Draw the character "{char}" {num_samples} times')
+    cols = st.columns(num_samples)
+    for i in range(num_samples):
+        with cols[i]:
+            canvas = st_canvas(
+                fill_color="white",
+                stroke_width=3,
+                stroke_color="black",
+                background_color="white",
+                width=150,
+                height=150,
+                drawing_mode="freedraw",
+                key=f"{char}_{i}",
+            )
+            if canvas.image_data is not None:
+                img = Image.fromarray(
+                    canvas.image_data.astype("uint8"), "RGBA"
+                ).convert("L")
+                # Binarize
+                img = img.point(lambda x: 0 if x < 128 else 255, "1")
+                save_image(img, char, i, data_dir)
+
+# Hyperparameter inputs
+st.sidebar.header("Hyperparameter settings")
+learning_rate = st.sidebar.number_input(
+    "Learning rate", min_value=0.0001, max_value=1.0, value=0.001, step=0.0001, format="%.4f"
+)
+epochs = st.sidebar.number_input(
+    "Epochs", min_value=1, max_value=100, value=10, step=1
+)
+optimizer_name = st.sidebar.selectbox("Optimizer", ["SGD", "Adam", "RMSprop"])
+
+
+# Sampling settings
+st.sidebar.header("Sampling settings")
+noise_steps = st.sidebar.number_input(
+    "Noise steps", min_value=1, max_value=1000, value=1000, step=1
+)
+beta_start = st.sidebar.number_input(
+    "Initial beta", min_value=0.0, max_value=1.0, value=0.0001, step=0.0001, format="%.4f"
+)
+beta_end = st.sidebar.number_input(
+    "Final beta", min_value=0.0, max_value=1.0, value=0.02, step=0.0001, format="%.4f"
+)
+
+
+# Generate button
+if st.button("Generate"):
+    try:
+        # Device setup
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Instantiate the model with the sampling settings from the sidebar
+        diffusion_model = Diffusion(
+            noise_steps=noise_steps,
+            beta_start=beta_start,
+            beta_end=beta_end,
+            img_size=32,
+            num_classes=46,
+            c_in=3,
+            c_out=3,
+            device=device,
+        )
+
+        save_path = "models/model.pth"
+
+        # Load the model weights
+        diffusion_model.model.load_state_dict(
+            torch.load(save_path, weights_only=True, map_location=device)
+        )
+
+        # Generate and display the font images
+        def chunked(iterable, n):
+            for i in range(0, len(iterable), n):
+                yield iterable[i : i + n]
+
+        labels = list(range(46))
+        columns_per_row = 5
+        start_time = time.time()
+        with st.spinner("Generating font images..."):
+            labels_tensor = torch.tensor(labels).to(device)
+            font_image = diffusion_model.sample(
+                diffusion_model.model, labels_tensor
+            )
+        elapsed_time = time.time() - start_time
+        st.success(f"Font images generated successfully ({elapsed_time:.2f} s)")
+        for label_row in chunked(labels, columns_per_row):
+            cols = st.columns(columns_per_row)
+            for col, label in zip(cols, label_row):
+                col.image(
+                    font_image[label].permute(1, 2, 0).cpu().numpy(),
+                    caption=f"{label}",
+                    use_container_width=True,
+                )
+
+    except FileNotFoundError as e:
+        st.error(str(e))
+    except ValueError as e:
+        st.error(str(e))
+    except Exception as e:
+        st.error(f"An unexpected error occurred: {e}")
+        st.error(traceback.format_exc())
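
For reference, the generation path in app.py can be smoke-tested outside Streamlit. A minimal sketch, assuming a trained checkpoint already exists at models/model.pth (the file is gitignored and not part of this commit):

import torch

from models.diffusion import Diffusion

# Pick a device and build the wrapper exactly as app.py does.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diffusion = Diffusion(img_size=32, num_classes=46, c_in=3, c_out=3, device=device)

# Load the fine-tuned weights (raises FileNotFoundError if the checkpoint is absent).
diffusion.model.load_state_dict(
    torch.load("models/model.pth", weights_only=True, map_location=device)
)

# One image per class; sample() returns uint8 tensors in (N, C, H, W).
labels = torch.arange(46).to(device)
images = diffusion.sample(diffusion.model, labels)
print(images.shape)  # expected: torch.Size([46, 3, 32, 32])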
models/__init__.py ADDED
File without changes
models/diffusion.py ADDED
@@ -0,0 +1,196 @@
+import numpy as np
+import torch
+from .unet import UNet_conditional
+
+
+class Diffusion:
+    def __init__(
+        self,
+        noise_steps: int = 1000,
+        beta_start: float = 1e-4,
+        beta_end: float = 0.02,
+        img_size: int = 32,
+        num_classes: int = 46,
+        c_in: int = 3,
+        c_out: int = 3,
+        device: torch.device = torch.device("cuda"),
+        time_dim: int = 256,
+        **kwargs,
+    ):
+
+        self.noise_steps = noise_steps
+        self.beta_start = beta_start
+        self.beta_end = beta_end
+        self.img_size = img_size
+        self.device = device
+        self.time_dim = time_dim
+        self.num_classes = num_classes
+        self.c_in = c_in
+        self.c_out = c_out
+        self.model = UNet_conditional(
+            c_in, c_out, time_dim, num_classes=num_classes, **kwargs
+        ).to(device)
+
+        self.beta = self.prepare_noise_schedule().to(device)
+        self.alpha = 1.0 - self.beta
+        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
+
+    def __call__(self, x, t, labels):
+        return self.model(x, t, labels)
+
+    def prepare_noise_schedule(self) -> torch.Tensor:
+        """
+        Build the linear noise schedule.
+        """
+        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)
+
+    def noise_images(self, x: torch.Tensor, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Create the noised images and the noise itself.
+        """
+        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
+        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
+        noise = torch.randn_like(x)
+        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise
+
+    def sample_timesteps(self, n: int) -> torch.Tensor:
+        """
+        Sample random timesteps.
+        """
+        return torch.randint(low=1, high=self.noise_steps, size=(n,))
+
+    def sample(self, model: torch.nn.Module, labels: torch.Tensor) -> torch.Tensor:
+        """
+        Generate images.
+        """
+        self.model = model
+        n = len(labels)
+        print(f"Sampling {n} new images....")
+        model.eval()
+
+        with torch.no_grad():
+            x = torch.randn((n, self.c_in, self.img_size, self.img_size)).to(
+                self.device
+            )
+            for i in reversed(range(1, self.noise_steps)):
+                t = (torch.ones(n) * i).long().to(self.device)
+                predicted_noise = model(x, t, labels)
+                alpha = self.alpha[t][:, None, None, None]
+                alpha_hat = self.alpha_hat[t][:, None, None, None]
+                beta = self.beta[t][:, None, None, None]
+                if i > 1:
+                    noise = torch.randn_like(x)
+                else:
+                    noise = torch.zeros_like(x)
+                x = (
+                    1
+                    / torch.sqrt(alpha)
+                    * (
+                        x
+                        - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise
+                    )
+                    + torch.sqrt(beta) * noise
+                )
+        model.train()
+        x = (x.clamp(-1, 1) + 1) / 2
+        x = (x * 255).type(torch.uint8)
+        return x
+
+    def fit(
+        self,
+        optimizer: torch.optim.Optimizer,
+        criterion: torch.nn.Module,
+        num_epochs: int,
+        train_loader: torch.utils.data.DataLoader,
+        test_loader: torch.utils.data.DataLoader,
+        model: torch.nn.Module,
+        device: torch.device,
+        history: np.ndarray,
+        save_path: str,
+    ) -> np.ndarray:
+        """
+        Train the model.
+        """
+        base_epochs = len(history)
+        # Initialize the minimum loss
+        min_test_loss = float("inf")
+
+        for epoch in range(base_epochs, base_epochs + num_epochs):
+            # Cumulative loss per epoch (before averaging)
+            train_loss, test_loss = 0, 0
+            # Cumulative sample count per epoch
+            n_train, n_test = 0, 0
+
+            # Training phase
+            self.model.train()
+            for x, labels in train_loader:
+                # Number of samples in this batch
+                train_batch_size = len(labels)
+                # Cumulative sample count per epoch
+                n_train += train_batch_size
+
+                # Move to the GPU
+                x = x.to(device)
+                labels = labels.to(device)
+
+                # Sample noise timesteps
+                t = self.sample_timesteps(x.size(0)).to(device)
+                # Create the noised images and the noise
+                xt, noise = self.noise_images(x, t)
+
+                # Reset the gradients
+                optimizer.zero_grad()
+                # Forward pass
+                predicted_noise = model(xt, t, labels)
+                # Compute the loss
+                loss = criterion(predicted_noise, noise)
+                # Backward pass
+                loss.backward()
+                # Update the parameters
+                optimizer.step()
+
+                # Accumulate the (pre-averaging) loss
+                train_loss += loss.item() * train_batch_size
+
+            # Evaluation phase
+            self.model.eval()
+            for x, labels in test_loader:
+                # Number of samples in this batch
+                test_batch_size = len(labels)
+                # Cumulative sample count per epoch
+                n_test += test_batch_size
+
+                # Move to the GPU
+                x = x.to(device)
+                labels = labels.to(device)
+
+                # Sample noise timesteps
+                t = self.sample_timesteps(x.size(0)).to(device)
+                # Create the noised images and the noise
+                xt, noise = self.noise_images(x, t)
+                # Forward pass
+                predicted_noise = model(xt, t, labels)
+                # Compute the loss
+                loss = criterion(predicted_noise, noise)
+
+                # Accumulate the (pre-averaging) loss
+                test_loss += loss.item() * test_batch_size
+
+            # Average the losses
+            avg_train_loss = train_loss / n_train
+            avg_test_loss = test_loss / n_test
+
+            # Update the minimum loss and save the model
+            if avg_test_loss < min_test_loss:
+                min_test_loss = avg_test_loss
+                torch.save(self.model.state_dict(), save_path)
+
+            # Report the results
+            print(
+                f"Epoch {epoch + 1}, Train loss: {avg_train_loss:.3f}, Test loss: {avg_test_loss:.3f}"
+            )
+            # Record the history
+            item = np.array([epoch + 1, avg_train_loss, avg_test_loss])
+            history = np.vstack([history, item])
+
+        return history
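
Diffusion.fit expects an optimizer, a loss criterion, data loaders, and a history array with one row per completed epoch. A minimal usage sketch; the random TensorDataset below is a placeholder assumption, not data from this repo:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from models.diffusion import Diffusion

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diffusion = Diffusion(img_size=32, num_classes=46, device=device)

# Placeholder batch: 8 random 3x32x32 images with labels in [0, 46).
x = torch.randn(8, 3, 32, 32)
y = torch.randint(0, 46, (8,))
loader = DataLoader(TensorDataset(x, y), batch_size=4)

history = diffusion.fit(
    optimizer=torch.optim.Adam(diffusion.model.parameters(), lr=1e-3),
    criterion=torch.nn.MSELoss(),
    num_epochs=1,
    train_loader=loader,
    test_loader=loader,
    model=diffusion.model,
    device=device,
    history=np.empty((0, 3)),  # columns: epoch, train loss, test loss
    save_path="models/model.pth",
)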
models/unet.py ADDED
@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def one_param(m: nn.Module) -> nn.Parameter:
+    """
+    Get the model's first parameter.
+    """
+    return next(iter(m.parameters()))
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, channels):
+        super(SelfAttention, self).__init__()
+        self.channels = channels
+        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
+        self.ln = nn.LayerNorm([channels])
+        self.ff_self = nn.Sequential(
+            nn.LayerNorm([channels]),
+            nn.Linear(channels, channels),
+            nn.GELU(),
+            nn.Linear(channels, channels),
+        )
+
+    def forward(self, x):
+        size = x.shape[-1]
+        x = x.view(-1, self.channels, size * size).swapaxes(1, 2)
+        x_ln = self.ln(x)
+        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
+        attention_value = attention_value + x
+        attention_value = self.ff_self(attention_value) + attention_value
+        return attention_value.swapaxes(2, 1).view(-1, self.channels, size, size)
+
+
+class DoubleConv(nn.Module):
+    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
+        super().__init__()
+        self.residual = residual
+        if not mid_channels:
+            mid_channels = out_channels
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+            nn.GroupNorm(1, mid_channels),
+            nn.GELU(),
+            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.GroupNorm(1, out_channels),
+        )
+
+    def forward(self, x):
+        if self.residual:
+            return F.gelu(x + self.double_conv(x))
+        else:
+            return self.double_conv(x)
+
+
+class Down(nn.Module):
+    def __init__(self, in_channels, out_channels, emb_dim=256):
+        super().__init__()
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, in_channels, residual=True),
+            DoubleConv(in_channels, out_channels),
+        )
+
+        self.emb_layer = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(emb_dim, out_channels),
+        )
+
+    def forward(self, x, t):
+        x = self.maxpool_conv(x)
+        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
+        return x + emb
+
+
+class Up(nn.Module):
+    def __init__(self, in_channels, out_channels, emb_dim=256):
+        super().__init__()
+
+        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+        self.conv = nn.Sequential(
+            DoubleConv(in_channels, in_channels, residual=True),
+            DoubleConv(in_channels, out_channels, in_channels // 2),
+        )
+
+        self.emb_layer = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(emb_dim, out_channels),
+        )
+
+    def forward(self, x, skip_x, t):
+        x = self.up(x)
+        x = torch.cat([skip_x, x], dim=1)
+        x = self.conv(x)
+        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
+        return x + emb
+
+
+# UNet definition
+class UNet(nn.Module):
+    def __init__(self, c_in=3, c_out=3, time_dim=256, remove_deep_conv=False):
+        super().__init__()
+        self.time_dim = time_dim
+        self.remove_deep_conv = remove_deep_conv
+        self.inc = DoubleConv(c_in, 64)
+        self.down1 = Down(64, 128)
+        self.sa1 = SelfAttention(128)
+        self.down2 = Down(128, 256)
+        self.sa2 = SelfAttention(256)
+        self.down3 = Down(256, 256)
+        self.sa3 = SelfAttention(256)
+
+        if remove_deep_conv:
+            self.bot1 = DoubleConv(256, 256)
+            self.bot3 = DoubleConv(256, 256)
+        else:
+            self.bot1 = DoubleConv(256, 512)
+            self.bot2 = DoubleConv(512, 512)
+            self.bot3 = DoubleConv(512, 256)
+
+        self.up1 = Up(512, 128)
+        self.sa4 = SelfAttention(128)
+        self.up2 = Up(256, 64)
+        self.sa5 = SelfAttention(64)
+        self.up3 = Up(128, 64)
+        self.sa6 = SelfAttention(64)
+        self.outc = nn.Conv2d(64, c_out, kernel_size=1)
+
+    def pos_encoding(self, t, channels):
+        inv_freq = 1.0 / (
+            10000
+            ** (
+                torch.arange(0, channels, 2, device=one_param(self).device).float()
+                / channels
+            )
+        )
+        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
+        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
+        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
+        return pos_enc
+
+    def unet_forward(self, x, t):
+        x1 = self.inc(x)
+        x2 = self.down1(x1, t)
+        x2 = self.sa1(x2)
+        x3 = self.down2(x2, t)
+        x3 = self.sa2(x3)
+        x4 = self.down3(x3, t)
+        x4 = self.sa3(x4)
+
+        x4 = self.bot1(x4)
+        if not self.remove_deep_conv:
+            x4 = self.bot2(x4)
+        x4 = self.bot3(x4)
+
+        x = self.up1(x4, x3, t)
+        x = self.sa4(x)
+        x = self.up2(x, x2, t)
+        x = self.sa5(x)
+        x = self.up3(x, x1, t)
+        x = self.sa6(x)
+        output = self.outc(x)
+        return output
+
+    def forward(self, x, t):
+        t = t.unsqueeze(-1)
+        t = self.pos_encoding(t, self.time_dim)
+        return self.unet_forward(x, t)
+
+
+class UNet_conditional(UNet):
+    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=46, **kwargs):
+        super().__init__(c_in, c_out, time_dim, **kwargs)
+        self.label_emb = nn.Embedding(num_classes, time_dim)
+
+    def forward(self, x, t, y):
+        """
+        Add the label embedding to the timestep embedding.
+        """
+        t = t.unsqueeze(-1)
+        t = self.pos_encoding(t, self.time_dim)
+
+        if y is not None:
+            t += self.label_emb(y)
+
+        return self.unet_forward(x, t)
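
As a quick shape check for the conditional U-Net (a sketch; the batch size and label values are arbitrary):

import torch

from models.unet import UNet_conditional

model = UNet_conditional(c_in=3, c_out=3, time_dim=256, num_classes=46)

x = torch.randn(2, 3, 32, 32)     # batch of two 32x32 RGB images
t = torch.randint(1, 1000, (2,))  # one diffusion timestep per image
y = torch.tensor([0, 45])         # class labels in [0, 46)

out = model(x, t, y)              # predicted noise, same shape as x
print(out.shape)                  # torch.Size([2, 3, 32, 32])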
requirements.txt CHANGED
@@ -1 +1,4 @@
 streamlit==1.40.1
+streamlit-drawable-canvas==0.9.3
+torch==2.5.1
+torchvision==0.20.1
utils.py ADDED
@@ -0,0 +1,39 @@
+import os
+import tempfile
+from PIL import Image
+
+
+def initialize_data_dir() -> str:
+    if "data_dir" not in os.environ:
+        data_dir = tempfile.mkdtemp()
+        os.environ["data_dir"] = data_dir
+    return os.environ["data_dir"]
+
+
+def save_image(img, char, idx, data_dir):
+    char_dir = os.path.join(data_dir, char)
+    os.makedirs(char_dir, exist_ok=True)
+    img_path = os.path.join(char_dir, f"{idx}.png")
+    img.save(img_path)
+
+
+def load_images(characters, num_samples, data_dir, transform):
+    X = []
+    y = []
+    for label, char in enumerate(characters):
+        char_dir = os.path.join(data_dir, char)
+        if not os.path.exists(char_dir):
+            raise FileNotFoundError(
+                f'Images for the character "{char}" are missing. Please draw all of the samples.'
+            )
+        for i in range(num_samples):
+            img_path = os.path.join(char_dir, f"{i}.png")
+            if not os.path.exists(img_path):
+                raise FileNotFoundError(
+                    f'Sample {i+1} for the character "{char}" does not exist.'
+                )
+            img = Image.open(img_path).convert("L")
+            img = transform(img)
+            X.append(img)
+            y.append(label)
+    return X, y
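
load_images pairs every saved PNG with its class index and applies the given transform, raising FileNotFoundError when a sample is missing. A minimal usage sketch; the torchvision transform below is an assumption, not part of this commit (Grayscale(num_output_channels=3) is included because the canvases are saved single-channel while the app's Diffusion is configured with c_in=3):

import torch
from torchvision import transforms

from utils import initialize_data_dir, load_images

data_dir = initialize_data_dir()

transform = transforms.Compose([
    transforms.Resize((32, 32)),                  # match the model's img_size
    transforms.Grayscale(num_output_channels=3),  # match the model's c_in=3
    transforms.ToTensor(),                        # scale pixels to [0, 1]
])

X, y = load_images(
    ["あ", "い", "う", "え", "お"], num_samples=5, data_dir=data_dir, transform=transform
)
dataset = torch.utils.data.TensorDataset(torch.stack(X), torch.tensor(y))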