import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model_module import AutoencoderModule
from dataset import MyDataset, load_filenames
from utils import DistanceMapLogger
import numpy as np
from PIL import Image
import base64
from io import BytesIO
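# Overview: a Gradio demo that loads a fixed batch of gallery images, runs them
# through a pretrained autoencoder, and shows where the feature sampled at a
# chosen (x, y) point reappears in a target image. The target can be an uploaded
# photo, cropped client-side with Cropper.js and sent back as a Base64 data URL.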
# Load the model and data
def load_model():
    model_path = "checkpoints/ae_model_tf_2024-03-05_00-35-21.pth"
    feature_dim = 32
    model = AutoencoderModule(feature_dim=feature_dim)
    # map_location keeps the load working on CPU-only hosts; the model is moved
    # to the selected device below.
    state_dict = torch.load(model_path, map_location="cpu")

    # Remap the state_dict keys: prefix each with "model." to match the
    # AutoencoderModule wrapper.
    new_state_dict = {}
    for key in state_dict:
        new_key = "model." + key
        new_state_dict[new_key] = state_dict[key]
    model.load_state_dict(new_state_dict)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print("Model loaded successfully.")
    return model, device
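# A fixed gallery batch is sampled once below and reused for every request; an
# uploaded image is appended to it on the fly inside get_heatmaps.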
def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=32):
    filenames = load_filenames(img_dir)
    train_X = filenames[:1000]
    train_ds = MyDataset(train_X, img_dir=img_dir, img_size=image_size)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    iterator = iter(train_loader)
    x, _, _ = next(iterator)
    # Each sample appears to be a stack; keep only the first frame
    x = x[:, 0].to(device)
    print("Data loaded successfully.")
    return x
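# Module-level initialization: this runs once per process (at import time on a
# Space), so the same model and gallery batch serve all sessions.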
model, device = load_model()
image_size = 112
batch_size = 32
x = load_data(device)
# Preprocess an uploaded image to match the training format
def preprocess_uploaded_image(uploaded_image, image_size):
    uploaded_image = Image.fromarray(uploaded_image)
    uploaded_image = uploaded_image.convert("RGB")
    uploaded_image = uploaded_image.resize((image_size, image_size))
    # HWC uint8 -> CHW float in [0, 1]
    uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
    uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
    return uploaded_image
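# The returned tensor has shape (1, 3, image_size, image_size), matching the
# preloaded batch layout so the two can be concatenated in get_heatmaps.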
# Heatmap generation: sample the feature at the chosen (x, y) on the source image
# and visualize how close every pixel's feature is to it, in source and target.
# On a ZeroGPU Space, GPU time is requested per call; decorating the handler with
# @spaces.GPU is the usual pattern (the `spaces` import is otherwise unused).
@spaces.GPU
def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
    with torch.no_grad():
        dec5, _ = model(x)
    img = x
    feature_map = dec5
    batch_size = feature_map.size(0)
    feature_dim = feature_map.size(1)

    # Preprocess the uploaded image and append it to the batch
    if uploaded_image is not None:
        uploaded_image = preprocess_uploaded_image(uploaded_image, image_size)
        with torch.no_grad():
            target_feature_map, _ = model(uploaded_image)
        img = torch.cat((img, uploaded_image))
        feature_map = torch.cat((feature_map, target_feature_map))
        batch_size += 1
    else:
        # Placeholder only; nothing reads this when no image is uploaded
        uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
    target_num = batch_size - 1

    # Query feature vector at (x, y) for every image in the batch
    x_coords = [x_coords] * batch_size
    y_coords = [y_coords] * batch_size
    vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
    vector = vectors[source_num]

    # Per-pixel distance to the query vector, squashed with sech^2 so small
    # distances map to ~1 and large distances decay quickly
    reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
    batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
    norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2

    source_map = norm_batch_distance_map[source_num]
    target_map = norm_batch_distance_map[target_num]

    # Overlay the normalized similarity map on the red channel of each image
    alpha = 0.8
    blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
    blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
    # Plot the four panels with Matplotlib and return the figure
    fig, axs = plt.subplots(2, 2, figsize=(10, 10))
    axs[0, 0].imshow(source_map.cpu(), cmap='hot')
    axs[0, 0].set_title("Source Map")
    axs[0, 1].imshow(target_map.cpu(), cmap='hot')
    axs[0, 1].set_title("Target Map")
    axs[1, 0].imshow(blended_source.permute(1, 2, 0).cpu())
    axs[1, 0].set_title("Blended Source")
    axs[1, 1].imshow(blended_target.permute(1, 2, 0).cpu())
    axs[1, 1].set_title("Blended Target")
    for ax in axs.flat:
        ax.axis('off')
    plt.tight_layout()
    plt.close(fig)
    return fig
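# Note: plt.close(fig) above only detaches the figure from pyplot's state; the
# returned figure object is still renderable by gr.Plot, and closing keeps
# figures from accumulating across requests.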
def process_image(cropped_image_data):
    # Convert the Base64 data URL into a PIL image; the "data:...;base64," header
    # before the first comma is discarded
    header, base64_data = cropped_image_data.split(',', 1)
    image_data = base64.b64decode(base64_data)
    image = Image.open(BytesIO(image_data))
    return image
# JavaScript: load Cropper.js from a CDN and wire it to the hidden file input
scripts = """
async () => {
    const script = document.createElement("script");
    script.src = "https://cdnjs.cloudflare.com/ajax/libs/cropperjs/1.5.13/cropper.min.js";
    document.head.appendChild(script);

    const style = document.createElement("link");
    style.rel = "stylesheet";
    style.href = "https://cdnjs.cloudflare.com/ajax/libs/cropperjs/1.5.13/cropper.min.css";
    document.head.appendChild(style);

    script.onload = () => {
        let cropper;
        document.getElementById("input_file_button").onclick = function() {
            document.querySelector("#input_file").click();
        };

        // Load the selected image into the crop view
        document.querySelector("#input_file").addEventListener("change", function(e) {
            const files = e.target.files;
            console.log(files);
            if (files && files.length > 0) {
                console.log("File selected");
                document.querySelector("#crop_view").style.display = "block";
                document.querySelector("#crop_button").style.display = "block";
                const url = URL.createObjectURL(files[0]);
                const crop_view = document.getElementById("crop_view");
                crop_view.src = url;
                if (cropper) {
                    cropper.destroy();
                }
                cropper = new Cropper(crop_view, {
                    aspectRatio: 1,
                    viewMode: 1,
                });
            }
        });

        // Attach the crop action to the Gradio button
        document.getElementById("crop_button").onclick = function() {
            if (cropper) {
                const canvas = cropper.getCroppedCanvas();
                const croppedImageData = canvas.toDataURL();
                // Send the cropped image back to Gradio via the hidden textbox
                const textbox = document.querySelector("#cropped_image_data textarea");
                textbox.value = croppedImageData;
                textbox.dispatchEvent(new Event("input", { bubbles: true }));
                document.getElementById("crop_view").style.display = "none";
                document.getElementById("crop_button").style.display = "none";
                cropper.destroy();
            }
        };

        document.getElementById("crop_view").style.display = "none";
        document.getElementById("crop_button").style.display = "none";
    };
}
"""
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            source_num = gr.Slider(0, batch_size - 1, step=1, label="Source Image Index")
            x_coords = gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="X Coordinate")
            y_coords = gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate")

            # Hidden native file input; the Gradio button below triggers it
            gr.HTML('<input type="file" id="input_file" style="display:none;">')
            input_file_button = gr.Button("Select Image", elem_id="input_file_button")

            # Plain <img> tag that Cropper.js attaches to
            gr.HTML('<img id="crop_view" style="max-width:100%;">')

            # Gradio button with an element ID so the JS above can hook it
            crop_button = gr.Button("Crop", elem_id="crop_button", variant="primary")

            # Hidden textbox that receives the cropped image as Base64 data
            cropped_image_data = gr.Textbox(visible=False, elem_id="cropped_image_data")
            input_image = gr.Image(label="Cropped Image", interactive=False)

            # Decode the Base64 payload whenever the hidden textbox updates
            cropped_image_data.change(process_image, inputs=cropped_image_data, outputs=input_image)
        with gr.Column():
            output_plot = gr.Plot()

    # Re-render the heatmaps whenever any input changes (in place of gr.Interface)
    source_num.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    x_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    y_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    input_image.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)

    # Inject the Cropper.js bootstrap script on page load
    demo.load(None, None, None, js=scripts)

demo.launch()
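# Run locally with `python app.py`; Gradio serves on http://127.0.0.1:7860 by
# default. A Hugging Face Space executes this file the same way.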