import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model_module import AutoencoderModule
from dataset import MyDataset, load_filenames
from utils import DistanceMapLogger
import numpy as np
from PIL import Image
import base64
from io import BytesIO

# Load the model and data
def load_model():
    model_path = "checkpoints/ae_model_tf_2024-03-05_00-35-21.pth"
    feature_dim = 32
    model = AutoencoderModule(feature_dim=feature_dim)
    state_dict = torch.load(model_path)

    # Fix the state_dict keys (prepend the "model." prefix)
    new_state_dict = {}
    for key in state_dict:
        new_key = "model." + key
        new_state_dict[new_key] = state_dict[key]
    model.load_state_dict(new_state_dict)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print("Model loaded successfully.")
    return model, device

def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=32):
    filenames = load_filenames(img_dir)
    train_X = filenames[:1000]
    train_ds = MyDataset(train_X, img_dir=img_dir, img_size=image_size)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )
    iterator = iter(train_loader)
    x, _, _ = next(iterator)
    x = x[:, 0].to(device)  # take the first entry along dim 1 and move to the device
    print("Data loaded successfully.")
    return x

model, device = load_model()
image_size = 112
batch_size = 32
x = load_data(device)

# Preprocess an uploaded image
def preprocess_uploaded_image(uploaded_image, image_size):
    uploaded_image = Image.fromarray(uploaded_image)
    uploaded_image = uploaded_image.convert("RGB")
    uploaded_image = uploaded_image.resize((image_size, image_size))
    uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
    uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
    return uploaded_image

# Heatmap generation function
@spaces.GPU
def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
    with torch.no_grad():
        dec5, _ = model(x)
        img = x
        feature_map = dec5
        batch_size = feature_map.size(0)
        feature_dim = feature_map.size(1)

        # Preprocess the uploaded image, if any, and append it to the batch
        if uploaded_image is not None:
            uploaded_image = preprocess_uploaded_image(uploaded_image, image_size)
            target_feature_map, _ = model(uploaded_image)
            img = torch.cat((img, uploaded_image))
            feature_map = torch.cat((feature_map, target_feature_map))
            batch_size += 1
        else:
            uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)

        target_num = batch_size - 1
        x_coords = [x_coords] * batch_size
        y_coords = [y_coords] * batch_size

        # Query feature vector at the selected (x, y) of the source image
        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
        vector = vectors[source_num]

        # Per-pixel L2 distance between every feature and the query vector
        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(
            feature_map.size(0), image_size, image_size
        )
        norm_batch_distance_map = 1 / torch.cosh(
            20 * (batch_distance_map - batch_distance_map.min())
            / (batch_distance_map.max() - batch_distance_map.min())
        ) ** 2

        source_map = norm_batch_distance_map[source_num]
        target_map = norm_batch_distance_map[target_num]

        # Overlay each heatmap (red channel) on a dimmed copy of its image
        alpha = 0.8
        blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat((
            (norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0),
            torch.zeros(2, image_size, image_size, device=device),
        ))
        blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat((
            (norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0),
            torch.zeros(2, image_size, image_size, device=device),
        ))
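        # Note on the two steps above: batch_distance_map is min-max normalized
        # to [0, 1] and passed through the peak function
        # sech(20 * d)^2 = 1 / cosh(20 * d)^2, which equals 1 where a pixel's
        # feature matches the query vector and decays rapidly toward 0; the
        # blends then write that response into the red channel over a dimmed
        # (1 - alpha) copy of the source and target images.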
        # Render the four panels with matplotlib and return the figure
        fig, axs = plt.subplots(2, 2, figsize=(10, 10))
        axs[0, 0].imshow(source_map.cpu(), cmap='hot')
        axs[0, 0].set_title("Source Map")
        axs[0, 1].imshow(target_map.cpu(), cmap='hot')
        axs[0, 1].set_title("Target Map")
        axs[1, 0].imshow(blended_source.permute(1, 2, 0).cpu())
        axs[1, 0].set_title("Blended Source")
        axs[1, 1].imshow(blended_target.permute(1, 2, 0).cpu())
        axs[1, 1].set_title("Blended Target")
        for ax in axs.flat:
            ax.axis('off')
        plt.tight_layout()
        plt.close(fig)
        return fig

def process_image(cropped_image_data):
    # Convert the Base64 data URL into a PIL image
    header, base64_data = cropped_image_data.split(',', 1)
    image_data = base64.b64decode(base64_data)
    image = Image.open(BytesIO(image_data))
    return image

# JavaScript code (Cropper.js integration)
scripts = """
async () => {
    const script = document.createElement("script");
    script.src = "https://cdnjs.cloudflare.com/ajax/libs/cropperjs/1.5.13/cropper.min.js";
    document.head.appendChild(script);

    const style = document.createElement("link");
    style.rel = "stylesheet";
    style.href = "https://cdnjs.cloudflare.com/ajax/libs/cropperjs/1.5.13/cropper.min.css";
    document.head.appendChild(style);

    script.onload = () => {
        let cropper;

        document.getElementById("input_file_button").onclick = function() {
            document.querySelector("#input_file").click();
        };

        // Load the selected image from the hidden file input
        document.querySelector("#input_file").addEventListener("change", function(e) {
            const files = e.target.files;
            console.log(files);
            if (files && files.length > 0) {
                console.log("File selected");
                document.querySelector("#crop_view").style.display = "block";
                document.querySelector("#crop_button").style.display = "block";
                const url = URL.createObjectURL(files[0]);
                const crop_view = document.getElementById("crop_view");
                crop_view.src = url;
                if (cropper) {
                    cropper.destroy();
                }
                cropper = new Cropper(crop_view, {
                    aspectRatio: 1,
                    viewMode: 1,
                });
            }
        });

        // Attach the cropping behavior to the Gradio button
        document.getElementById("crop_button").onclick = function() {
            if (cropper) {
                const canvas = cropper.getCroppedCanvas();
                const croppedImageData = canvas.toDataURL();

                // Send the cropped image to Gradio
                const textbox = document.querySelector("#cropped_image_data textarea");
                textbox.value = croppedImageData;
                textbox.dispatchEvent(new Event("input", { bubbles: true }));

                document.getElementById("crop_view").style.display = "none";
                document.getElementById("crop_button").style.display = "none";
                cropper.destroy();
            }
        };

        document.getElementById("crop_view").style.display = "none";
        document.getElementById("crop_button").style.display = "none";
    };
}
"""
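# canvas.toDataURL() yields a string like "data:image/png;base64,iVBORw0...",
# so process_image splits on the first comma and decodes the remainder. A
# minimal round-trip sketch (illustrative only; this helper is not used
# anywhere else in the app):
def _example_data_url_roundtrip():
    buf = BytesIO()
    Image.new("RGB", (4, 4), "red").save(buf, format="PNG")
    data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
    cropped = process_image(data_url)
    assert cropped.size == (4, 4)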

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            source_num = gr.Slider(0, batch_size - 1, step=1, label="Source Image Index")
            x_coords = gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="X Coordinate")
            y_coords = gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate")

            # Hidden file input that the JavaScript opens via the button below
            # (markup reconstructed to match the IDs the script expects)
            gr.HTML('<input type="file" id="input_file" accept="image/*" style="display:none;">')
            input_file_button = gr.Button("Select Image", elem_id="input_file_button")

            # HTML image tag that Cropper.js attaches to as its crop view
            gr.HTML('<img id="crop_view" style="max-width:100%;">')

            # Gradio button with an elem_id so the JavaScript can find it
            crop_button = gr.Button("Crop", elem_id="crop_button", variant="primary")

            # Hidden textbox holding the cropped image as Base64 data
            cropped_image_data = gr.Textbox(visible=False, elem_id="cropped_image_data")
            input_image = gr.Image(label="Cropped Image", interactive=False)

            # Call process_image whenever cropped_image_data updates
            cropped_image_data.change(process_image, inputs=cropped_image_data, outputs=input_image)

        with gr.Column():
            output_plot = gr.Plot()

    # Event wiring in place of a gr.Interface
    source_num.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    x_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    y_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    input_image.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)

    # Load the JavaScript code
    demo.load(None, None, None, js=scripts)

demo.launch()
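
# Illustrative usage (assumes the checkpoint and the images in
# resources/trainB/ are present, and that this file is saved as app.py):
#
#   python app.py            # starts the Gradio server and prints a local URL
#
# get_heatmaps can also be called directly, bypassing the UI, e.g.:
#
#   fig = get_heatmaps(source_num=0, x_coords=56, y_coords=56, uploaded_image=None)
#   fig.savefig("heatmaps.png")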