import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model_module import AutoencoderModule
from dataset import MyDataset, load_filenames
from utils import DistanceMapLogger
import numpy as np
from PIL import Image
import base64
from io import BytesIO
# Load the model and data
def load_model():
    model_path = "checkpoints/ae_model_tf_2024-03-05_00-35-21.pth"
    feature_dim = 32
    model = AutoencoderModule(feature_dim=feature_dim)
    state_dict = torch.load(model_path)

    # Fix the state_dict keys: the checkpoint stores the bare weights, while
    # AutoencoderModule expects them under a "model." prefix
    new_state_dict = {}
    for key in state_dict:
        new_key = "model." + key
        new_state_dict[new_key] = state_dict[key]
    model.load_state_dict(new_state_dict)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print("Model loaded successfully.")
    return model, device
def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=32):
    filenames = load_filenames(img_dir)
    train_X = filenames[:1000]
    train_ds = MyDataset(train_X, img_dir=img_dir, img_size=image_size)
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )
    iterator = iter(train_loader)
    x, _, _ = next(iterator)
    x = x[:, 0].to(device)
    print("Data loaded successfully.")
    return x
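
# NOTE (assumption): MyDataset is taken here to yield tuples whose first
# element is an image tensor of shape (B, T, C, H, W); x[:, 0] in load_data
# keeps the first frame/view of each sample, giving a (B, 3, H, W) batch.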
model, device = load_model()
image_size = 112
batch_size = 32
x = load_data(device)
# Preprocess an uploaded image
def preprocess_uploaded_image(uploaded_image, image_size):
    uploaded_image = Image.fromarray(uploaded_image)
    uploaded_image = uploaded_image.convert("RGB")
    uploaded_image = uploaded_image.resize((image_size, image_size))
    uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
    uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
    return uploaded_image
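
# Minimal shape-level sanity check for the helper above ("sample.png" is a
# hypothetical file, used purely for illustration):
#
#   arr = np.array(Image.open("sample.png").convert("RGB"))
#   t = preprocess_uploaded_image(arr, image_size)
#   assert t.shape == (1, 3, image_size, image_size)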
# Heatmap generation function
@spaces.GPU
def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
    with torch.no_grad():
        dec5, _ = model(x)
        img = x
        feature_map = dec5
        batch_size = feature_map.size(0)
        feature_dim = feature_map.size(1)

        # Preprocess the uploaded image and append it to the batch
        if uploaded_image is not None:
            uploaded_image = preprocess_uploaded_image(uploaded_image, image_size)
            target_feature_map, _ = model(uploaded_image)
            img = torch.cat((img, uploaded_image))
            feature_map = torch.cat((feature_map, target_feature_map))
            batch_size += 1
        else:
            uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)

        target_num = batch_size - 1
        x_coords = [x_coords] * batch_size
        y_coords = [y_coords] * batch_size

        # Pick the feature vector at the query pixel (y, x) of every image
        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
        vector = vectors[source_num]

        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
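        # The transform below converts distances to similarities with a
        # sech^2 bump: distances are min-max normalized over the whole batch,
        # scaled by 20, and passed through 1 / cosh(.)^2, which equals 1 where
        # the distance is minimal and decays smoothly toward 0; the factor 20
        # sets how sharply the heatmap falls off.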
        norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
        source_map = norm_batch_distance_map[source_num]
        target_map = norm_batch_distance_map[target_num]

        # Blend the heatmap into the image's red channel
        alpha = 0.8
        blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
        blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))

        # Plot with Matplotlib and return the figure
        fig, axs = plt.subplots(2, 2, figsize=(10, 10))
        axs[0, 0].imshow(source_map.cpu(), cmap='hot')
        axs[0, 0].set_title("Source Map")
        axs[0, 1].imshow(target_map.cpu(), cmap='hot')
        axs[0, 1].set_title("Target Map")
        axs[1, 0].imshow(blended_source.permute(1, 2, 0).cpu())
        axs[1, 0].set_title("Blended Source")
        axs[1, 1].imshow(blended_target.permute(1, 2, 0).cpu())
        axs[1, 1].set_title("Blended Target")
        for ax in axs.flat:
            ax.axis('off')
        plt.tight_layout()
        plt.close(fig)  # close so pyplot does not keep the figure alive; gr.Plot still renders it
        return fig
def process_image(cropped_image_data):
    # Convert the Base64 data URL to a PIL image
    header, base64_data = cropped_image_data.split(',', 1)
    image_data = base64.b64decode(base64_data)
    image = Image.open(BytesIO(image_data))
    return image
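
# NOTE: canvas.toDataURL() yields "data:image/png;base64,<data>", so the
# split(',', 1) above strips the "data:...;base64" header before decoding.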
# JavaScript code
scripts = """
async () => {
    const script = document.createElement("script");
    script.src = "https://cdnjs.cloudflare.com/ajax/libs/cropperjs/1.5.13/cropper.min.js";
    document.head.appendChild(script);

    const style = document.createElement("link");
    style.rel = "stylesheet";
    style.href = "https://cdnjs.cloudflare.com/ajax/libs/cropperjs/1.5.13/cropper.min.css";
    document.head.appendChild(style);

    script.onload = () => {
        let cropper;

        document.getElementById("input_file_button").onclick = function() {
            document.querySelector("#input_file").click();
        };

        // Load the selected image from the hidden file input
        document.querySelector("#input_file").addEventListener("change", function(e) {
            const files = e.target.files;
            console.log(files);
            if (files && files.length > 0) {
                console.log("File selected");
                document.querySelector("#crop_view").style.display = "block";
                document.querySelector("#crop_button").style.display = "block";
                const url = URL.createObjectURL(files[0]);
                const crop_view = document.getElementById("crop_view");
                crop_view.src = url;
                if (cropper) {
                    cropper.destroy();
                }
                cropper = new Cropper(crop_view, {
                    aspectRatio: 1,
                    viewMode: 1,
                });
            }
        });

        // Attach the cropping logic to the Gradio button
        document.getElementById("crop_button").onclick = function() {
            if (cropper) {
                const canvas = cropper.getCroppedCanvas();
                const croppedImageData = canvas.toDataURL();

                // Send the cropped image to Gradio
                const textbox = document.querySelector("#cropped_image_data textarea");
                textbox.value = croppedImageData;
                textbox.dispatchEvent(new Event("input", { bubbles: true }));

                document.getElementById("crop_view").style.display = "none";
                document.getElementById("crop_button").style.display = "none";
                cropper.destroy();
            }
        };

        document.getElementById("crop_view").style.display = "none";
        document.getElementById("crop_button").style.display = "none";
    };
}
"""
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            source_num = gr.Slider(0, batch_size - 1, step=1, label="Source Image Index")
            x_coords = gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="X Coordinate")
            y_coords = gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate")

            # Hidden native file input, opened via the Gradio button below
            gr.HTML('<input type="file" id="input_file" style="display:none;">')
            input_file_button = gr.Button("Select Image", elem_id="input_file_button")

            # HTML image tag that Cropper.js attaches to
            gr.HTML('<img id="crop_view" style="max-width:100%;">')

            # Gradio button component with a fixed element ID so the JS can hook it
            crop_button = gr.Button("Crop", elem_id="crop_button", variant="primary")

            # Hidden textbox holding the cropped image data (Base64)
            cropped_image_data = gr.Textbox(visible=False, elem_id="cropped_image_data")
            input_image = gr.Image(label="Cropped Image", interactive=False)

            # Call process_image whenever cropped_image_data is updated
            cropped_image_data.change(process_image, inputs=cropped_image_data, outputs=input_image)

        with gr.Column():
            output_plot = gr.Plot()

    # Event wiring in place of a gr.Interface
    source_num.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    x_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    y_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    input_image.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)

    # Load the JavaScript code
    demo.load(None, None, None, js=scripts)

demo.launch()