# app.py — author: ryo2, commit 88d65a7 ("Create app.py").
import gradio as gr
import pandas as pd
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np
import japanize_matplotlib
class DataProcessor:
    """Load a DeepLabCut-style tracking CSV and render coordinate/likelihood plots.

    The CSV is read with ``header=[1, 2]`` so columns form a two-level
    ``(bodypart, coord)`` MultiIndex. The first column (presumably the frame
    index of the export) is dropped, and ``likelihood`` columns are split into
    a separate frame. Every image is written under ``self.output_folder``.
    """

    def __init__(self, bodypart_names, x_max, y_max):
        """Store plotting settings.

        Args:
            bodypart_names: Comma-separated display names, one per bodypart in
                the CSV, in column order. Whitespace around each name is
                ignored and empty entries (e.g. from a trailing comma) are
                dropped.
            x_max: Upper x-axis limit (pixels) for the scatter plots.
            y_max: Upper y-axis limit (pixels) for the scatter plots.
        """
        # Strip each name: "a, b" and "a,b" must give the same names, otherwise
        # leading spaces leak into legends and output file names.
        self.bodypart_names = [
            name.strip() for name in bodypart_names.split(',') if name.strip()
        ]
        self.x_max = x_max
        self.y_max = y_max
        # Relative directory (under the working directory) for all PNG output.
        self.output_folder = 'output_plots'

    def process_csv(self, file_path):
        """Read *file_path* and return ``(coords_df, likelihood_df)``.

        Both frames use the user-supplied names as the first column level.
        ``coords_df`` holds every non-likelihood column except the first;
        ``likelihood_df`` holds only the ``likelihood`` columns.

        Raises:
            ValueError: if the number of names differs from the number of
                bodyparts in the CSV (pandas parse errors are also ValueErrors).
        """
        df = pd.read_csv(file_path, header=[1, 2])
        df_likelihood = self.extract_likelihood(df)
        df = self.remove_first_column_and_likelihood(df)
        df = self.rename_bodyparts(df)
        return df, df_likelihood

    def remove_first_column_and_likelihood(self, df):
        """Return *df* without its first column and without likelihood columns.

        Likelihood columns are matched by an exact ``'likelihood'`` entry in the
        second column level, so a bodypart whose *name* contains the word
        "likelihood" is kept.
        """
        df = df.iloc[:, 1:]
        is_coord = df.columns.get_level_values(1) != 'likelihood'
        return df.loc[:, is_coord]

    def rename_bodyparts(self, df):
        """Replace the bodypart level of *df*'s columns with the user names.

        Raises:
            ValueError: if the number of names differs from the number of
                distinct bodyparts in *df*.
        """
        current_names = df.columns.get_level_values(0).unique()
        if len(self.bodypart_names) != len(current_names):
            raise ValueError(
                "The length of bodypart_names must be equal to the number of bodyparts.")
        return self._apply_bodypart_names(df)

    def extract_likelihood(self, df):
        """Return only the ``likelihood`` columns of *df*, renamed to the user names.

        No length check is done here: bodyparts beyond the supplied names keep
        their original labels.
        """
        # The leading index column's coord label is not 'likelihood', so the
        # mask already excludes it; nothing else needs dropping.
        is_likelihood = df.columns.get_level_values(1) == 'likelihood'
        return self._apply_bodypart_names(df.loc[:, is_likelihood])

    def _apply_bodypart_names(self, df):
        """Return *df* with its top column level mapped to ``self.bodypart_names``.

        Bodyparts are paired with names in order of first appearance; a bodypart
        without a corresponding name keeps its original label. The input frame
        is not modified.
        """
        if df.shape[1] == 0:
            # from_tuples cannot infer the level count from an empty list.
            return df
        current_names = df.columns.get_level_values(0).unique()
        mapping = dict(zip(current_names, self.bodypart_names))
        new_columns = pd.MultiIndex.from_tuples(
            [(mapping.get(part, part), coord) for part, coord in df.columns])
        return df.set_axis(new_columns, axis=1)

    @staticmethod
    def _save_figure(path, image_paths):
        """Save the current pyplot figure to *path*, close it, and record *path*."""
        plt.savefig(path)
        plt.close()
        image_paths.append(path)

    @staticmethod
    def _set_pixel_axes(x_max, y_max):
        """Limit the current axes to ``[0, x_max] x [0, y_max]`` with pixel labels."""
        plt.xlim(0, x_max)
        plt.ylim(0, y_max)
        # Invert so y=0 is drawn at the top, like pixel coordinates.
        plt.gca().invert_yaxis()
        plt.xlabel('X Coordinate(pixel)')
        plt.ylabel('Y Coordinate(pixel)')

    def plot_scatter(self, df):
        """Write scatter plots for *df* into ``self.output_folder``; return image paths."""
        os.makedirs(self.output_folder, exist_ok=True)
        return self.plot_scatter_fixed(df, self.output_folder, self.x_max, self.y_max)

    def plot_scatter_fixed(self, df, output_folder, x_max, y_max):
        """Draw one x/y scatter per bodypart plus a combined scatter.

        Each per-bodypart plot marks the first frame's position as "Start".

        Returns:
            Paths of the written PNGs: ``<bodypart>.png`` for each bodypart,
            then ``all_plot.png``.
        """
        os.makedirs(output_folder, exist_ok=True)
        image_paths = []
        bodyparts = df.columns.get_level_values(0).unique()
        colors = plt.cm.rainbow(np.linspace(0, 1, len(bodyparts)))
        for bodypart, color in zip(bodyparts, colors):
            x = df[bodypart]['x'].values
            y = df[bodypart]['y'].values
            plt.figure(figsize=(8, 6))
            plt.scatter(x, y, color=color, label=bodypart)
            if len(x) > 0:
                # Highlight where the track begins (skipped for empty data).
                plt.scatter(x[0], y[0], color='black', marker='o', s=100)
                plt.text(x[0], y[0], ' Start', color='black',
                         fontsize=12, verticalalignment='bottom')
            self._set_pixel_axes(x_max, y_max)
            plt.title(f'トラッキングの座標({bodypart})')
            plt.legend(loc='upper right')
            self._save_figure(os.path.join(output_folder, f'{bodypart}.png'),
                              image_paths)

        plt.figure(figsize=(8, 6))
        for bodypart, color in zip(bodyparts, colors):
            plt.scatter(df[bodypart]['x'].values, df[bodypart]['y'].values,
                        color=color, label=bodypart)
        self._set_pixel_axes(x_max, y_max)
        plt.title('トラッキングの座標(全付属肢)')
        plt.legend(loc='upper right')
        self._save_figure(os.path.join(output_folder, 'all_plot.png'), image_paths)
        return image_paths

    @staticmethod
    def _plot_xy_over_time(df, bodypart, color):
        """Plot a bodypart's x (dashed) and y (solid) values against frame position."""
        plt.plot(df[bodypart]['x'].values, color=color,
                 label=bodypart + "(x座標)", linestyle='dashed')
        plt.plot(df[bodypart]['y'].values, color=color,
                 label=bodypart + "(y座標)")

    @staticmethod
    def _label_trajectory_axes(title):
        """Apply title, axis labels and legend to the current trajectory figure."""
        plt.title(title)
        plt.xlabel('Frames')
        plt.ylabel('Coordinate(pixel)')
        plt.legend(loc='upper right')

    def plot_trajectories(self, df):
        """Plot x/y coordinates over frames, per bodypart and combined.

        Returns:
            Paths of ``<bodypart>_trajectories.png`` for each bodypart, then
            ``all_trajectories.png``.
        """
        # Create the folder here too: this method may run without plot_scatter.
        os.makedirs(self.output_folder, exist_ok=True)
        image_paths = []
        bodyparts = df.columns.get_level_values(0).unique()
        colors = plt.cm.rainbow(np.linspace(0, 1, len(bodyparts)))
        for bodypart, color in zip(bodyparts, colors):
            plt.figure(figsize=(8, 6))
            self._plot_xy_over_time(df, bodypart, color)
            self._label_trajectory_axes(f'トラッキングの座標({bodypart})')
            self._save_figure(
                os.path.join(self.output_folder, f'{bodypart}_trajectories.png'),
                image_paths)

        plt.figure(figsize=(8, 6))
        for bodypart, color in zip(bodyparts, colors):
            self._plot_xy_over_time(df, bodypart, color)
        # The combined chart covers every bodypart, so title it accordingly.
        self._label_trajectory_axes('トラッキングの座標(全付属肢)')
        self._save_figure(os.path.join(self.output_folder, 'all_trajectories.png'),
                          image_paths)
        return image_paths

    @staticmethod
    def _label_likelihood_axes():
        """Apply the shared likelihood-plot styling to the current figure."""
        plt.ylim(0, 1.0)
        plt.xlabel('Frames')
        plt.ylabel('尤度')
        plt.title('フレーム別の尤度')
        # Place the legend outside the axes, at the upper right.
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)
        # Shrink the axes so the external legend is not clipped.
        plt.tight_layout()

    def plot_likelihood(self, likelihood_df):
        """Plot per-frame likelihood, per bodypart and combined.

        Returns:
            Paths of ``<bodypart>_likelihood.png`` for each bodypart, then
            ``likelihood_plot.png``.
        """
        # No pyplot call may happen before plt.figure(): an implicit figure
        # would be created and never closed.
        os.makedirs(self.output_folder, exist_ok=True)
        image_paths = []
        bodyparts = likelihood_df.columns.get_level_values(0).unique()
        colors = plt.cm.rainbow(np.linspace(0, 1, len(bodyparts)))
        for bodypart, color in zip(bodyparts, colors):
            plt.figure(figsize=(8, 6))
            plt.plot(likelihood_df[bodypart], color=color, label=bodypart)
            self._label_likelihood_axes()
            self._save_figure(
                os.path.join(self.output_folder, f'{bodypart}_likelihood.png'),
                image_paths)

        plt.figure(figsize=(8, 6))
        for bodypart, color in zip(bodyparts, colors):
            plt.plot(likelihood_df[bodypart], color=color, label=bodypart)
        self._label_likelihood_axes()
        self._save_figure(os.path.join(self.output_folder, 'likelihood_plot.png'),
                          image_paths)
        return image_paths
# Gradio web UI definition and script entry point.
class GradioInterface:
    """Gradio front end: upload a tracking CSV, choose graphs, get images and a ZIP."""

    def __init__(self):
        """Build the ``gr.Interface``; call :meth:`launch` to serve it."""
        self.interface = gr.Interface(
            fn=self.process_and_plot,
            inputs=[
                gr.File(label="CSVファイルをドラッグ&ドロップ"),
                gr.Textbox(label="付属肢の名前(カンマ区切り)",
                           value="指節1, 指節2, 指節3, 指節4, 指節5, 指節6, 指節7, 指節8, 指節9,指節10, 指節11, 指節12, 指節13, 指節14, 触角(左), 触角(右), 頭部, 腹尾節"),
                gr.Number(label="X軸の最大値", value=1920),
                gr.Number(label="Y軸の最大値", value=1080),
                gr.CheckboxGroup(
                    label="プロットするグラフを選択",
                    choices=["散布図", "軌跡図", "尤度グラフ"],
                    value=["散布図", "軌跡図", "尤度グラフ"],
                    type="value"
                )
            ],
            outputs=[
                # The gallery shows every selected graph type, not only scatter plots.
                gr.Gallery(label="グラフ"),
                gr.File(label="ZIPダウンロード")
            ],
            title="DeepLabCutグラフ出力ツール",
            description="CSVファイルからグラフを作成します。"
        )

    def process_and_plot(self, file, bodypart_names, x_max, y_max, graph_choices):
        """Gradio callback: render the selected graphs and bundle them into a ZIP.

        Args:
            file: The uploaded CSV. Depending on the installed Gradio version
                this is either an object exposing the path as ``.name`` or the
                path string itself; ``None`` when nothing was uploaded.
            bodypart_names: Comma-separated display names for the bodyparts.
            x_max: Upper x-axis limit for scatter plots.
            y_max: Upper y-axis limit for scatter plots.
            graph_choices: Selected labels among "散布図", "軌跡図", "尤度グラフ".

        Returns:
            ``(image_paths, zip_path)`` for the gallery and download outputs.

        Raises:
            gr.Error: if no file was uploaded or the CSV/names fail validation.
        """
        if file is None:
            raise gr.Error("CSVファイルをアップロードしてください。")
        file_path = getattr(file, 'name', file)

        processor = DataProcessor(bodypart_names, x_max, y_max)
        # Start from an empty folder so the ZIP contains only this run's graphs.
        shutil.rmtree(processor.output_folder, ignore_errors=True)
        os.makedirs(processor.output_folder, exist_ok=True)

        try:
            df, df_likelihood = processor.process_csv(file_path)
        except ValueError as err:
            # Show name-count mismatches and CSV parse errors in the UI.
            raise gr.Error(str(err)) from err

        selected = graph_choices or []
        all_image_paths = []
        if "散布図" in selected:
            all_image_paths += processor.plot_scatter(df)
        if "軌跡図" in selected:
            all_image_paths += processor.plot_trajectories(df)
        if "尤度グラフ" in selected:
            all_image_paths += processor.plot_likelihood(df_likelihood)

        zip_path = shutil.make_archive(processor.output_folder, 'zip',
                                       processor.output_folder)
        return all_image_paths, zip_path

    def launch(self):
        """Start the Gradio web server for this interface."""
        self.interface.launch()
if __name__ == "__main__":
    # Run as a script: build the web UI and serve it.
    GradioInterface().launch()