import torch from transformers import ViTForImageClassification, ViTImageProcessor import matplotlib.pyplot as plt import gradio as gr import plotly.graph_objects as go import torch import numpy as np from PIL import Image model_name = "./best_model" processor = ViTImageProcessor.from_pretrained(model_name) labels = ['Акне или розацеа', 'Актинический кератоз, базальноклеточная карцинома и другие злокачественные поражения', 'Атопический дерматит', 'Буллезное заболевание', 'Целлюлит, импетиго и другие бактериальные инфекции', 'Контактный дерматит', 'Экзема', 'Экзантемы и лекарственные высыпания', 'Фотографии потери волос, алопеция и другие заболевания волос', 'Герпес, ВПЧ и другие ЗППП', 'Легкие заболевания и нарушения пигментации', 'Волчанка и другие заболевания соединительной ткани', 'Меланома, рак кожи, невусы и родинки', 'Грибок ногтей и другие заболевания ногтей', 'Фотографии псориаза, красный плоский лишай и связанные с ним заболевания', 'Чесотка, болезнь Лайма и другие инвазии и укусы', 'Себорейный кератоз и другие Доброкачественные опухоли', 'Системные заболевания', 'Опоясывающий лишай, кандидоз и другие грибковые инфекции', 'Крапивница', 'Сосудистые опухоли', 'Васкулит', 'Бородавки, моллюск и другие вирусные инфекции'] class ViTForImageClassificationWithAttention(ViTForImageClassification): def forward(self, pixel_values): outputs = super().forward(pixel_values) attention = self.vit.encoder.layers[0].attention.attention_weights return outputs, attention model = ViTForImageClassificationWithAttention.from_pretrained(model_name) class ViTForImageClassificationWithAttention(ViTForImageClassification): def forward(self, pixel_values, output_attentions=True): outputs = super().forward(pixel_values, output_attentions=output_attentions) attention = outputs.attentions return outputs, attention model = ViTForImageClassificationWithAttention.from_pretrained(model_name,attn_implementation="eager") i_count = 0 def classify_image(image): model_name = "best_model.pth" model.load_state_dict(torch.load(model_name)) inputs = processor(images=image, return_tensors="pt") outputs, attention = model(**inputs, output_attentions=True) logits = outputs.logits probs = torch.nn.functional.softmax(logits, dim=1) top_k_probs, top_k_indices = torch.topk(probs, k=5) # show top 5 predicted labels predicted_class_idx = torch.argmax(logits) predicted_class_label = labels[predicted_class_idx] top_k_labels = [labels[idx] for idx in top_k_indices[0]] top_k_label_probs = [(label, prob.item()) for label, prob in zip(top_k_labels, top_k_probs[0])] # Create a bar chart fig_bar = go.Figure( data=[go.Bar(x=[label for label, prob in top_k_label_probs], y=[prob for label, prob in top_k_label_probs])]) fig_bar.update_layout(title="Топ 5 диагнозов в порядке убывания вероятности", xaxis_title="Диагноз", yaxis_title="Вероятность") # Create a heatmap if attention is not None: fig_heatmap = go.Figure( data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)]) fig_heatmap.update_layout(title="Карта внимания системы") else: fig_heatmap = go.Figure() # Return an empty plot # Overlay the attention heatmap on the input image if attention is not None: img_array = np.array(image) heatmap = np.array(attention[0][0, 0, :, :].detach().numpy()) heatmap = np.resize(heatmap, (img_array.shape[0], img_array.shape[1])) heatmap = heatmap / heatmap.max() * 255 # Normalize heatmap to [0, 255] heatmap = heatmap.astype(np.uint8) heatmap_color = np.zeros((img_array.shape[0], img_array.shape[1], 3), dtype=np.uint8) heatmap_color[:, :, 0] = heatmap # Red channel heatmap_color[:, :, 1] = heatmap # Green channel heatmap_color[:, :, 2] = 0 # Blue channel attention_overlay = (img_array * 0.35 + heatmap_color * 0.75).astype(np.uint8) attention_overlay = Image.fromarray(attention_overlay) attention_overlay.save("attention_overlay.png") attention_overlay = gr.Image("attention_overlay.png") else: attention_overlay = gr.Image() # Return an empty image # Return the predicted label, the bar chart, and the heatmap return predicted_class_label, fig_bar, fig_heatmap, attention_overlay def update_model(image, label): # Convert the label to an integer label_idx = labels.index(label) labels_tensor = torch.tensor([label_idx]) inputs = processor(images=image, return_tensors="pt") loss_fn = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Zero the gradients optimizer.zero_grad() # Forward pass outputs, attention = model(**inputs) loss = loss_fn(outputs.logits, labels_tensor) # Backward pass loss.backward() # Update the model parameters optimizer.step() # Save the updated model torch.save(model.state_dict(), "best_model.pth") return "Модель успешно обновлена" demo = gr.TabbedInterface( [ gr.Interface( fn=classify_image, inputs=[ gr.Image(type="pil", label="Image") ], outputs=[ gr.Label(label="Предсказанный диагноз"), gr.Plot(label="Топ 5 диагнозов в порядке убывания вероятности") ], title="DermaScan Demo", description="Загрузите изображение, чтобы увидеть прогнозируемую метку класса, 5 лучших прогнозируемых меток с вероятностями и тепловую карту внимания.", allow_flagging=False ), gr.Interface( fn=update_model, inputs=[ gr.Image(type="pil", label="Image"), gr.Radio( choices=labels, type="value", label="Label", value=labels[0] ) ], outputs=[ gr.Textbox(label="Обновление модели") ], title="Обучить модель", description="Загрузите изображение и метку для обновления модели.", allow_flagging=False ) ], title="DermaScan Demo" ) if __name__ == "__main__": demo.launch()