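# Source: app.py by DHEIVER (Hugging Face Space), commit f9f0754 "Update app.py", 3.31 kB.
# Hub file-view notes: virus scan reported "No virus"; page links: raw · history · blame.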
import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
import numpy as np
import cv2
class ThyroidTumorClassifierApp:
    """Gradio app that classifies a thyroid image with a Hugging Face model.

    The feature extractor and classification model are loaded once when the
    app is constructed. Each request runs one forward pass and returns the
    input image with the predicted label and per-class probabilities drawn
    on it.
    """

    # Hugging Face Hub checkpoint loaded when no other model is requested.
    DEFAULT_MODEL_NAME = "SerdarHelli/ThyroidTumorClassificationModel"
    # Display names indexed by predicted class id; customize to match the model.
    DEFAULT_CLASS_LABELS = ("Sem Tumor", "Tumor")

    def __init__(self, model_name=DEFAULT_MODEL_NAME, class_labels=None):
        """Load the feature extractor and the image-classification model.

        Args:
            model_name: Hub id or local path of the checkpoint to load.
            class_labels: Optional sequence of display names indexed by class
                id. Defaults to ``DEFAULT_CLASS_LABELS``.
        """
        self.extractor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)
        if class_labels is None:
            class_labels = self.DEFAULT_CLASS_LABELS
        self.class_labels = list(class_labels)

    def classify_image(self, image):
        """Classify ``image`` and return it annotated with the prediction.

        Args:
            image: Image array delivered by the Gradio ``Image`` component
                (presumably H x W x 3 uint8 RGB -- TODO confirm if the
                component's image_mode is changed), or ``None`` when the form
                is submitted without an upload.

        Returns:
            A copy of the image with the predicted label and class
            probabilities drawn on it, or ``None`` if no image was given.
        """
        if image is None:
            # Nothing to classify; an empty output clears the result panel.
            return None

        inputs = self.extractor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping to save time and memory.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        probabilities = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        predicted_label = self.class_labels[predicted_class]
        return self.add_info_to_image(image, predicted_label, probabilities)

    def add_info_to_image(self, image, predicted_label, probabilities):
        """Draw the predicted label and class probabilities on a copy of ``image``.

        Args:
            image: Image array; it is not modified.
            predicted_label: Text shown after "Classe Prevista:".
            probabilities: Tensor whose first row holds one probability per
                class (shape (1, num_classes) as produced by classify_image).

        Returns:
            A new C-contiguous image array carrying the text overlay.
        """
        # The overlay colour is white on every channel, so channel order is
        # irrelevant. A C-contiguous copy (which cv2.putText needs) replaces
        # the former BGR->RGB->BGR round-trip, which only copied the array.
        annotated = np.array(image, order="C")

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        font_thickness = 1
        text_color = (255, 255, 255)
        x, y = 10, 30

        cv2.putText(
            annotated, f"Classe Prevista: {predicted_label}", (x, y),
            font, font_scale, text_color, font_thickness,
        )

        # Plain Python floats: formatting no longer depends on tensor/autograd state.
        # NOTE(review): rows use generic "Classe {i}" names rather than
        # self.class_labels -- consider showing the label names instead.
        for i, prob in enumerate(probabilities[0].detach().cpu().tolist()):
            # First probability row sits 60 px below the label line; rows are 30 px apart.
            row_y = y + 60 + i * 30
            cv2.putText(
                annotated, f"Classe {i}: {prob:.2f}", (x, row_y),
                font, font_scale, text_color, font_thickness,
            )
        return annotated

    def run_interface(self):
        """Build a Gradio interface around ``classify_image`` and launch it."""
        interface = gr.Interface(
            fn=self.classify_image,
            # gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in
            # Gradio 4; the unified gr.Image component serves both directions.
            inputs=gr.Image(type="numpy"),
            outputs=gr.Image(type="numpy"),
            title="Tumor da Tireoide Classificação",
            description="Faça o upload de uma imagem de um tumor da tireoide para classificação. A saída inclui o rótulo da classe prevista e as probabilidades com informações adicionais.",
        )
        interface.launch()
if __name__ == "__main__":
    # Script entry point: construct the app (loads the model) and start the web UI.
    ThyroidTumorClassifierApp().run_interface()