DermaTestDemo / app.py
ZDPLI's picture
Update app.py
f349e34 verified
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
import matplotlib.pyplot as plt
import gradio as gr
import plotly.graph_objects as go
import torch
import numpy as np
from PIL import Image
model_name = "./best_model"
processor = ViTImageProcessor.from_pretrained(model_name)
labels = ['Акне или розацеа', 'Актинический кератоз, базальноклеточная карцинома и другие злокачественные поражения', 'Атопический дерматит', 'Буллезное заболевание', 'Целлюлит, импетиго и другие бактериальные инфекции', 'Контактный дерматит', 'Экзема', 'Экзантемы и лекарственные высыпания', 'Фотографии потери волос, алопеция и другие заболевания волос', 'Герпес, ВПЧ и другие ЗППП', 'Легкие заболевания и нарушения пигментации', 'Волчанка и другие заболевания соединительной ткани', 'Меланома, рак кожи, невусы и родинки', 'Грибок ногтей и другие заболевания ногтей', 'Фотографии псориаза, красный плоский лишай и связанные с ним заболевания', 'Чесотка, болезнь Лайма и другие инвазии и укусы', 'Себорейный кератоз и другие Доброкачественные опухоли', 'Системные заболевания', 'Опоясывающий лишай, кандидоз и другие грибковые инфекции', 'Крапивница', 'Сосудистые опухоли', 'Васкулит', 'Бородавки, моллюск и другие вирусные инфекции']
class ViTForImageClassificationWithAttention(ViTForImageClassification):
def forward(self, pixel_values):
outputs = super().forward(pixel_values)
attention = self.vit.encoder.layers[0].attention.attention_weights
return outputs, attention
model = ViTForImageClassificationWithAttention.from_pretrained(model_name)
class ViTForImageClassificationWithAttention(ViTForImageClassification):
def forward(self, pixel_values, output_attentions=True):
outputs = super().forward(pixel_values, output_attentions=output_attentions)
attention = outputs.attentions
return outputs, attention
model = ViTForImageClassificationWithAttention.from_pretrained(model_name,attn_implementation="eager")
i_count = 0
def classify_image(image):
model_name = "best_model.pth"
model.load_state_dict(torch.load(model_name))
inputs = processor(images=image, return_tensors="pt")
outputs, attention = model(**inputs, output_attentions=True)
logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=1)
top_k_probs, top_k_indices = torch.topk(probs, k=5) # show top 5 predicted labels
predicted_class_idx = torch.argmax(logits)
predicted_class_label = labels[predicted_class_idx]
top_k_labels = [labels[idx] for idx in top_k_indices[0]]
top_k_label_probs = [(label, prob.item()) for label, prob in zip(top_k_labels, top_k_probs[0])]
# Create a bar chart
fig_bar = go.Figure(
data=[go.Bar(x=[label for label, prob in top_k_label_probs], y=[prob for label, prob in top_k_label_probs])])
fig_bar.update_layout(title="Топ 5 диагнозов в порядке убывания вероятности", xaxis_title="Диагноз",
yaxis_title="Вероятность")
# Create a heatmap
if attention is not None:
fig_heatmap = go.Figure(
data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
fig_heatmap.update_layout(title="Карта внимания системы")
else:
fig_heatmap = go.Figure() # Return an empty plot
# Overlay the attention heatmap on the input image
if attention is not None:
img_array = np.array(image)
heatmap = np.array(attention[0][0, 0, :, :].detach().numpy())
heatmap = np.resize(heatmap, (img_array.shape[0], img_array.shape[1]))
heatmap = heatmap / heatmap.max() * 255 # Normalize heatmap to [0, 255]
heatmap = heatmap.astype(np.uint8)
heatmap_color = np.zeros((img_array.shape[0], img_array.shape[1], 3), dtype=np.uint8)
heatmap_color[:, :, 0] = heatmap # Red channel
heatmap_color[:, :, 1] = heatmap # Green channel
heatmap_color[:, :, 2] = 0 # Blue channel
attention_overlay = (img_array * 0.35 + heatmap_color * 0.75).astype(np.uint8)
attention_overlay = Image.fromarray(attention_overlay)
attention_overlay.save("attention_overlay.png")
attention_overlay = gr.Image("attention_overlay.png")
else:
attention_overlay = gr.Image() # Return an empty image
# Return the predicted label, the bar chart, and the heatmap
return predicted_class_label, fig_bar, fig_heatmap, attention_overlay
def update_model(image, label):
# Convert the label to an integer
label_idx = labels.index(label)
labels_tensor = torch.tensor([label_idx])
inputs = processor(images=image, return_tensors="pt")
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Zero the gradients
optimizer.zero_grad()
# Forward pass
outputs, attention = model(**inputs)
loss = loss_fn(outputs.logits, labels_tensor)
# Backward pass
loss.backward()
# Update the model parameters
optimizer.step()
# Save the updated model
torch.save(model.state_dict(), "best_model.pth")
return "Модель успешно обновлена"
demo = gr.TabbedInterface(
[
gr.Interface(
fn=classify_image,
inputs=[
gr.Image(type="pil", label="Image")
],
outputs=[
gr.Label(label="Предсказанный диагноз"),
gr.Plot(label="Топ 5 диагнозов в порядке убывания вероятности")
],
title="DermaScan Demo",
description="Загрузите изображение, чтобы увидеть прогнозируемую метку класса, 5 лучших прогнозируемых меток с вероятностями и тепловую карту внимания.",
allow_flagging=False
),
gr.Interface(
fn=update_model,
inputs=[
gr.Image(type="pil", label="Image"),
gr.Radio(
choices=labels,
type="value",
label="Label",
value=labels[0]
)
],
outputs=[
gr.Textbox(label="Обновление модели")
],
title="Обучить модель",
description="Загрузите изображение и метку для обновления модели.",
allow_flagging=False
)
],
title="DermaScan Demo"
)
if __name__ == "__main__":
demo.launch()