Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import os | |
import shutil | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import japanize_matplotlib | |
class DataProcessor:
    """Parse a DeepLabCut-style tracking CSV and save plots of it as PNGs.

    The CSV is read with its second and third rows as a two-level column
    header: level 0 holds the body-part name and level 1 holds ``x``, ``y``
    or ``likelihood``. Every plotting method writes its images into
    ``self.output_folder`` (created on demand) and returns the list of
    written file paths.
    """

    def __init__(self, x_max, y_max):
        """Store the axis limits (in pixels) used when drawing the plots.

        Args:
            x_max: Upper limit of the X axis.
            y_max: Upper limit of the Y axis.
        """
        self.x_max = x_max
        self.y_max = y_max
        self.output_folder = 'output_plots'
        self.bodypart_names = None  # populated by process_csv()

    def process_csv(self, file_path):
        """Read the CSV and split it into coordinates and likelihoods.

        Args:
            file_path: Path or file-like object accepted by ``pd.read_csv``.

        Returns:
            A tuple ``(df, df_likelihood, bodypart_names)`` where ``df`` holds
            the non-likelihood columns without the first column,
            ``df_likelihood`` holds only the likelihood columns, and
            ``bodypart_names`` is the list of level-0 column names with the
            first entry removed.
        """
        df = pd.read_csv(file_path, header=[1, 2])
        # Body-part names are taken from the CSV header itself.
        names = df.columns.get_level_values(0).unique().tolist()
        # The first entry labels the leading column (usually the
        # scorer/index column), not a body part, so it is skipped.
        self.bodypart_names = names[1:]
        df_likelihood = self.extract_likelihood(df)
        df = self.remove_first_column_and_likelihood(df)
        return df, df_likelihood, self.bodypart_names

    def remove_first_column_and_likelihood(self, df):
        """Return ``df`` without its first column and likelihood columns.

        Likelihood columns are identified by their level-1 label, the same
        rule ``extract_likelihood`` uses, so a body part whose *name* happens
        to contain "likelihood" keeps its coordinate columns.
        """
        df = df.iloc[:, 1:]
        keep = df.columns.get_level_values(1) != 'likelihood'
        return df.loc[:, keep]

    def extract_likelihood(self, df):
        """Return only the columns whose level-1 label is ``likelihood``."""
        is_likelihood = df.columns.get_level_values(1) == 'likelihood'
        return df.loc[:, is_likelihood]

    def plot_scatter(self, df):
        """Draw the scatter plots using this instance's folder and limits.

        Returns:
            List of paths of the PNG files written.
        """
        return self.plot_scatter_fixed(
            df, self.output_folder, self.x_max, self.y_max)

    def plot_scatter_fixed(self, df, output_folder, x_max, y_max):
        """Save one x/y scatter plot per body part plus a combined plot.

        Args:
            df: Coordinate frame with (body part, ``x``/``y``) columns.
            output_folder: Directory that receives the PNG files.
            x_max: Upper limit of the X axis.
            y_max: Upper limit of the Y axis.

        Returns:
            List of paths of the PNG files written; the combined plot
            (``all_plot.png``) is last.
        """
        os.makedirs(output_folder, exist_ok=True)
        image_paths = []
        bodyparts = df.columns.get_level_values(0).unique()
        colors = plt.cm.rainbow(np.linspace(0, 1, len(bodyparts)))

        for i, bodypart in enumerate(bodyparts):
            x = df[bodypart]['x'].values
            y = df[bodypart]['y'].values
            plt.figure(figsize=(8, 6))
            plt.scatter(x, y, color=colors[i], label=bodypart)
            # Mark the first tracked position; skipped when there are no rows.
            if len(x) > 0:
                plt.scatter(x[0], y[0], color='black', marker='o', s=100)
                plt.text(x[0], y[0], ' Start', color='black',
                         fontsize=12, verticalalignment='bottom')
            plt.xlim(0, x_max)
            plt.ylim(0, y_max)
            # Flip the Y axis so that 0 is at the top of the plot.
            plt.gca().invert_yaxis()
            plt.title(f'トラッキングの座標({bodypart})')
            plt.xlabel('X Coordinate(pixel)')
            plt.ylabel('Y Coordinate(pixel)')
            plt.legend(loc='upper right')
            path = f'{output_folder}/{bodypart}.png'
            plt.savefig(path)
            image_paths.append(path)
            plt.close()

        # Combined plot with every body part on the same axes.
        plt.figure(figsize=(8, 6))
        for i, bodypart in enumerate(bodyparts):
            x = df[bodypart]['x'].values
            y = df[bodypart]['y'].values
            plt.scatter(x, y, color=colors[i], label=bodypart)
        plt.xlim(0, x_max)
        plt.ylim(0, y_max)
        plt.gca().invert_yaxis()
        plt.title('トラッキングの座標(全付属肢)')
        plt.xlabel('X Coordinate(pixel)')
        plt.ylabel('Y Coordinate(pixel)')
        plt.legend(loc='upper right')
        path = f'{output_folder}/all_plot.png'
        plt.savefig(path)
        image_paths.append(path)
        plt.close()
        return image_paths

    def plot_trajectories(self, df):
        """Save per-frame x/y line plots per body part plus a combined plot.

        The x series is drawn dashed and the y series solid, both on one
        axis, so the axis limit is the larger of ``x_max`` and ``y_max``.

        Returns:
            List of paths of the PNG files written; the combined plot
            (``all_trajectories.png``) is last.
        """
        os.makedirs(self.output_folder, exist_ok=True)
        image_paths = []
        bodyparts = df.columns.get_level_values(0).unique()
        colors = plt.cm.rainbow(np.linspace(0, 1, len(bodyparts)))
        # Both coordinates share the vertical axis, so cover both ranges.
        coord_max = max(self.x_max, self.y_max)

        for i, bodypart in enumerate(bodyparts):
            x = df[bodypart]['x'].values
            y = df[bodypart]['y'].values
            plt.figure(figsize=(8, 6))
            plt.plot(x, color=colors[i], label=bodypart + "(x座標)",
                     linestyle='dashed')
            plt.plot(y, color=colors[i], label=bodypart + "(y座標)")
            plt.title(f'トラッキングの座標({bodypart})')
            plt.xlabel('Frames')
            plt.ylabel('Coordinate(pixel)')
            plt.ylim(0, coord_max)
            plt.legend(loc='upper right')
            path = f'{self.output_folder}/{bodypart}_trajectories.png'
            plt.savefig(path)
            image_paths.append(path)
            plt.close()

        # Combined plot with every body part on the same axes.
        plt.figure(figsize=(8, 6))
        for i, bodypart in enumerate(bodyparts):
            x = df[bodypart]['x'].values
            y = df[bodypart]['y'].values
            plt.plot(x, color=colors[i], label=bodypart + "(x座標)",
                     linestyle='dashed')
            plt.plot(y, color=colors[i], label=bodypart + "(y座標)")
        plt.title('トラッキングの座標(全付属肢)')
        plt.xlabel('Frames')
        plt.ylabel('Coordinate(pixel)')
        plt.ylim(0, coord_max)
        plt.legend(loc='upper right')
        path = f'{self.output_folder}/all_trajectories.png'
        plt.savefig(path)
        image_paths.append(path)
        plt.close()
        return image_paths

    def plot_likelihood(self, likelihood_df):
        """Save per-frame likelihood plots per body part plus a combined plot.

        Args:
            likelihood_df: Frame holding only the likelihood columns, as
                returned by ``extract_likelihood``.

        Returns:
            List of paths of the PNG files written; the combined plot
            (``likelihood_plot.png``) is last.
        """
        os.makedirs(self.output_folder, exist_ok=True)
        image_paths = []
        bodyparts = likelihood_df.columns.get_level_values(0).unique()
        colors = plt.cm.rainbow(np.linspace(0, 1, len(bodyparts)))
        # Look colours up by body part so the combined plot matches the
        # per-part plots regardless of column count or order.
        color_of = dict(zip(bodyparts, colors))

        # One likelihood plot per body part.
        for bodypart in bodyparts:
            plt.figure(figsize=(8, 6))
            plt.ylim(0, 1.0)
            plt.plot(likelihood_df[bodypart], color=color_of[bodypart],
                     label=bodypart)
            plt.xlabel('Frames')
            plt.ylabel('尤度')
            plt.title('フレーム別の尤度')
            # Place the legend outside the axes, at the upper right.
            plt.legend(bbox_to_anchor=(1.05, 1),
                       loc='upper left', borderaxespad=0)
            # Adjust the layout so the outside legend is not cut off.
            plt.tight_layout()
            path = f'{self.output_folder}/{bodypart}_likelihood.png'
            plt.savefig(path)
            image_paths.append(path)
            plt.close()

        # Likelihood of every body part on the same axes.
        plt.figure(figsize=(8, 6))
        plt.ylim(0, 1.0)
        for column in likelihood_df.columns:
            plt.plot(likelihood_df[column], color=color_of[column[0]],
                     label=column[0])
        plt.xlabel('Frames')
        plt.ylabel('尤度')
        plt.title('フレーム別の尤度')
        # Place the legend outside the axes, at the upper right.
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)
        # Adjust the layout so the outside legend is not cut off.
        plt.tight_layout()
        path = f'{self.output_folder}/likelihood_plot.png'
        plt.savefig(path)
        plt.close()
        image_paths.append(path)
        return image_paths
class GradioInterface:
    """Gradio front end: upload a CSV, choose graph types, get plots and a ZIP."""

    def __init__(self):
        """Build the ``gr.Interface`` that wires the inputs to process_and_plot."""
        self.interface = gr.Interface(
            fn=self.process_and_plot,
            inputs=[
                gr.File(label="CSVファイルをドラッグ&ドロップ"),
                gr.Number(label="X軸の最大値", value=1920),
                gr.Number(label="Y軸の最大値", value=1080),
                gr.CheckboxGroup(
                    label="プロットするグラフを選択",
                    choices=["散布図", "軌跡図", "尤度グラフ"],
                    value=["散布図", "軌跡図", "尤度グラフ"],
                    type="value"
                )
            ],
            outputs=[
                gr.Gallery(label="グラフ"),
                gr.File(label="ZIPダウンロード"),
                # Text output listing the body parts detected in the CSV.
                gr.Textbox(label="検出された付属肢")
            ],
            title="DeepLabCutグラフ出力ツール",
            description="CSVファイルからグラフを作成します。付属肢はCSVファイルから自動的に抽出されます。"
        )

    def process_and_plot(self, file, x_max, y_max, graph_choices):
        """Create the selected plots for the uploaded CSV and zip them.

        Args:
            file: Uploaded file as delivered by ``gr.File``; either a path
                string or an object exposing the path as ``.name``.
            x_max: Upper limit of the X axis.
            y_max: Upper limit of the Y axis.
            graph_choices: Selected graph labels from the checkbox group.

        Returns:
            A tuple ``(image_paths, zip_path, bodyparts_text)``: the PNG paths
            for the gallery, the path of the ZIP archive of the output
            folder, and the detected body-part names joined by ", ".

        Raises:
            gr.Error: If no file was uploaded.
        """
        if file is None:
            raise gr.Error("CSVファイルをアップロードしてください。")
        # Accept both a plain path and a file wrapper carrying ``.name``.
        csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
        graph_choices = graph_choices or []

        processor = DataProcessor(x_max, y_max)
        df, df_likelihood, bodypart_names = processor.process_csv(csv_path)

        # Start from an empty output folder so the ZIP only contains plots
        # from this request, not leftovers from earlier runs; recreating it
        # here also guarantees it exists whichever graph types are selected.
        shutil.rmtree(processor.output_folder, ignore_errors=True)
        os.makedirs(processor.output_folder, exist_ok=True)

        all_image_paths = []
        if "散布図" in graph_choices:
            all_image_paths += processor.plot_scatter(df)
        if "軌跡図" in graph_choices:
            all_image_paths += processor.plot_trajectories(df)
        if "尤度グラフ" in graph_choices:
            all_image_paths += processor.plot_likelihood(df_likelihood)

        # Join the body-part names for display.
        bodyparts_text = ", ".join(map(str, bodypart_names))
        shutil.make_archive(processor.output_folder, 'zip', processor.output_folder)
        return all_image_paths, processor.output_folder + '.zip', bodyparts_text

    def launch(self):
        """Start the Gradio server for this interface."""
        self.interface.launch()
if __name__ == "__main__":
    # Build the Gradio UI and start serving it only when this file is run as
    # a script, not when it is imported.
    gradio_app = GradioInterface()
    gradio_app.launch()