ryo2 commited on
Commit
88d65a7
·
verified ·
1 Parent(s): 3667fc7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +225 -0
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import os
4
+ import shutil
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import japanize_matplotlib
8
+
9
+
10
+ class DataProcessor:
11
+ def __init__(self, bodypart_names, x_max, y_max):
12
+ self.bodypart_names = bodypart_names.split(',')
13
+ self.x_max = x_max
14
+ self.y_max = y_max
15
+ self.output_folder = 'output_plots'
16
+
17
+ def process_csv(self, file_path):
18
+ df = pd.read_csv(file_path, header=[1, 2])
19
+ df_likelihood = self.extract_likelihood(df)
20
+ df = self.remove_first_column_and_likelihood(df)
21
+ df = self.rename_bodyparts(df)
22
+ return df, df_likelihood
23
+
24
+ def remove_first_column_and_likelihood(self, df):
25
+ df = df.drop(df.columns[0], axis=1)
26
+ df = df[df.columns.drop(list(df.filter(regex='likelihood')))]
27
+ return df
28
+
29
+ def rename_bodyparts(self, df):
30
+ current_names = df.columns.get_level_values(0).unique()
31
+ if len(self.bodypart_names) != len(current_names):
32
+ raise ValueError(
33
+ "The length of bodypart_names must be equal to the number of bodyparts.")
34
+ mapping = dict(zip(current_names, self.bodypart_names))
35
+ new_columns = [(mapping[col[0]], col[1]) if col[0]
36
+ in mapping else col for col in df.columns]
37
+ df.columns = pd.MultiIndex.from_tuples(new_columns)
38
+ return df
39
+
40
+ def extract_likelihood(self, df):
41
+ # likelihood列のみを抽出する
42
+ df = df[df.columns[df.columns.get_level_values(1) == 'likelihood']]
43
+ df.drop(df.columns[0], axis=1)
44
+ current_names = df.columns.get_level_values(0).unique()
45
+ mapping = dict(zip(current_names, self.bodypart_names))
46
+ new_columns = [(mapping[col[0]], col[1]) if col[0]
47
+ in mapping else col for col in df.columns]
48
+ df.columns = pd.MultiIndex.from_tuples(new_columns)
49
+
50
+ return df
51
+
52
+ def plot_scatter(self, df):
53
+ if not os.path.exists(self.output_folder):
54
+ os.makedirs(self.output_folder)
55
+ return self.plot_scatter_fixed(df, self.output_folder, self.x_max, self.y_max)
56
+
57
+ def plot_scatter_fixed(self, df, output_folder, x_max, y_max):
58
+ image_paths = []
59
+ bodyparts = df.columns.get_level_values(0).unique()
60
+ colors = plt.cm.rainbow(np.linspace(0, 1, len(bodyparts)))
61
+ for i, bodypart in enumerate(bodyparts):
62
+ x = df[bodypart]['x'].values
63
+ y = df[bodypart]['y'].values
64
+ plt.figure(figsize=(8, 6))
65
+ plt.scatter(x, y, color=colors[i], label=bodypart)
66
+ plt.scatter(x[0], y[0], color='black', marker='o', s=100)
67
+ plt.text(x[0], y[0], ' Start', color='black',
68
+ fontsize=12, verticalalignment='bottom')
69
+ plt.xlim(0, x_max)
70
+ plt.ylim(0, y_max)
71
+ plt.gca().invert_yaxis()
72
+ plt.title(f'トラッキングの座標({bodypart})')
73
+ plt.xlabel('X Coordinate(pixel)')
74
+ plt.ylabel('Y Coordinate(pixel)')
75
+ plt.legend(loc='upper right')
76
+ plt.savefig(f'{output_folder}/{bodypart}.png')
77
+ image_paths.append(f'{output_folder}/{bodypart}.png')
78
+ plt.close()
79
+
80
+ plt.figure(figsize=(8, 6))
81
+ for i, bodypart in enumerate(bodyparts):
82
+ x = df[bodypart]['x'].values
83
+ y = df[bodypart]['y'].values
84
+ plt.scatter(x, y, color=colors[i], label=bodypart)
85
+ plt.xlim(0, x_max)
86
+ plt.ylim(0, y_max)
87
+ plt.gca().invert_yaxis()
88
+ plt.title('トラッキングの座標(全付属肢)')
89
+ plt.xlabel('X Coordinate(pixel)')
90
+ plt.ylabel('Y Coordinate(pixel)')
91
+ plt.legend(loc='upper right')
92
+ plt.savefig(f'{output_folder}/all_plot.png')
93
+ image_paths.append(f'{output_folder}/all_plot.png')
94
+ plt.close()
95
+
96
+ return image_paths
97
+
98
+ def plot_trajectories(self, df):
99
+ image_paths = []
100
+ bodyparts = df.columns.get_level_values(0).unique()
101
+ colors = plt.cm.rainbow(np.linspace(0, 1, len(bodyparts)))
102
+
103
+ for i, bodypart in enumerate(bodyparts):
104
+ x = df[bodypart]['x'].values
105
+ y = df[bodypart]['y'].values
106
+ plt.figure(figsize=(8, 6))
107
+ plt.plot(x, color=colors[i], label=bodypart +
108
+ "(x座標)", linestyle='dashed')
109
+ plt.plot(y, color=colors[i], label=bodypart + "(y座標)")
110
+ plt.title(f'トラッキングの座標({bodypart})')
111
+ plt.xlabel('Frames')
112
+ plt.ylabel('Coordinate(pixel)')
113
+ plt.legend(loc='upper right')
114
+ plt.savefig(f'{self.output_folder}/{bodypart}_trajectories.png')
115
+ image_paths.append(
116
+ f'{self.output_folder}/{bodypart}_trajectories.png')
117
+ plt.close()
118
+
119
+ plt.figure(figsize=(8, 6))
120
+ for i, bodypart in enumerate(bodyparts):
121
+ x = df[bodypart]['x'].values
122
+ y = df[bodypart]['y'].values
123
+ plt.plot(x, color=colors[i], label=bodypart +
124
+ "(x座標)", linestyle='dashed')
125
+ plt.plot(y, color=colors[i], label=bodypart + "(y座標)")
126
+ plt.title(f'トラッキングの座標({bodypart})')
127
+ plt.xlabel('Frames')
128
+ plt.ylabel('Coordinate(pixel)')
129
+ plt.legend(loc='upper right')
130
+ plt.savefig(f'{self.output_folder}/all_trajectories.png')
131
+ image_paths.append(f'{self.output_folder}/all_trajectories.png')
132
+ plt.close()
133
+ return image_paths
134
+
135
+ def plot_likelihood(self, likelihood_df):
136
+ image_paths = []
137
+ plt.ylim(0, 1.0)
138
+ bodyparts = likelihood_df.columns.get_level_values(0).unique()
139
+ colors = plt.cm.rainbow(np.linspace(0, 1, len(bodyparts)))
140
+
141
+ # 付属肢ごとに尤度をプロット
142
+ for i, bodypart in enumerate(bodyparts):
143
+ plt.figure(figsize=(8, 6))
144
+ plt.ylim(0, 1.0)
145
+ plt.plot(likelihood_df[bodypart], color=colors[i], label=bodypart)
146
+ plt.xlabel('Frames')
147
+ plt.ylabel('尤度')
148
+ plt.title('フレーム別の尤度')
149
+ # 凡例を右上の外側に表示
150
+ plt.legend(bbox_to_anchor=(1.05, 1),
151
+ loc='upper left', borderaxespad=0)
152
+ # 凡例がはみ出さないようにレイアウトを調整
153
+ plt.tight_layout()
154
+ plt.savefig(f'{self.output_folder}/{bodypart}_likelihood.png')
155
+ image_paths.append(f'{self.output_folder}/{bodypart}_likelihood.png')
156
+ plt.close()
157
+
158
+ # 全ての付属肢の尤度をプロット
159
+ plt.figure(figsize=(8, 6))
160
+ plt.ylim(0, 1.0)
161
+ for i, column in enumerate(likelihood_df.columns):
162
+ plt.plot(likelihood_df[column], color=colors[i], label=column[0])
163
+ plt.xlabel('Frames')
164
+ plt.ylabel('尤度')
165
+ plt.title('フレーム別の尤度')
166
+ # 凡例を右上の外側に表示
167
+ plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)
168
+ # 凡例がはみ出さないようにレイアウトを調整
169
+ plt.tight_layout()
170
+ plt.savefig(f'{self.output_folder}/likelihood_plot.png')
171
+ plt.close()
172
+ image_paths.append(f'{self.output_folder}/likelihood_plot.png')
173
+ return image_paths
174
+
175
+ # 以下のGradioInterfaceクラスとメイン実行部分は変更なし
176
+
177
+
178
+ class GradioInterface:
179
+ def __init__(self):
180
+ self.interface = gr.Interface(
181
+ fn=self.process_and_plot,
182
+ inputs=[
183
+ gr.File(label="CSVファイルをドラッグ&ドロップ"),
184
+ gr.Textbox(label="付属肢の名前(カンマ区切り)",
185
+ value="指節1, 指節2, 指節3, 指節4, 指節5, 指節6, 指節7, 指節8, 指節9,指節10, 指節11, 指節12, 指節13, 指節14, 触角(左), 触角(右), 頭部, 腹尾節"),
186
+ gr.Number(label="X軸の最大値", value=1920),
187
+ gr.Number(label="Y軸の最大値", value=1080),
188
+ gr.CheckboxGroup(
189
+ label="プロットするグラフを選択",
190
+ choices=["散布図", "軌跡図", "尤度グラフ"],
191
+ value=["散布図", "軌跡図", "尤度グラフ"],
192
+ type="value"
193
+ )
194
+ ],
195
+ outputs=[
196
+ gr.Gallery(label="散布図"),
197
+ gr.File(label="ZIPダウンロード")
198
+ ],
199
+ title="DeepLabCutグラフ出力ツール",
200
+ description="CSVファイルからグラフを作成します。"
201
+ )
202
+
203
+ def process_and_plot(self, file, bodypart_names, x_max, y_max, graph_choices):
204
+ processor = DataProcessor(bodypart_names, x_max, y_max)
205
+ df, df_likelihood = processor.process_csv(file.name)
206
+
207
+ all_image_paths = []
208
+ if "散布図" in graph_choices:
209
+ all_image_paths += processor.plot_scatter(df)
210
+ if "軌跡図" in graph_choices:
211
+ all_image_paths += processor.plot_trajectories(df)
212
+ if "尤度グラフ" in graph_choices:
213
+ all_image_paths += processor.plot_likelihood(df_likelihood)
214
+
215
+ shutil.make_archive(processor.output_folder,
216
+ 'zip', processor.output_folder)
217
+ return all_image_paths, processor.output_folder + '.zip'
218
+
219
+ def launch(self):
220
+ self.interface.launch()
221
+
222
+
223
+ if __name__ == "__main__":
224
+ gradio_app = GradioInterface()
225
+ gradio_app.launch()