DHEIVER commited on
Commit
990cc09
1 Parent(s): 96c9ba0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -1
app.py CHANGED
@@ -1,3 +1,62 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- gr.Interface.load("models/AhmadHakami/alzheimer-image-classification-google-vit-base-patch16").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import datetime
3
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
4
+ import torch
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
 
8
+ # Carregar o extrator de recursos e o modelo diretamente
9
+ extractor = AutoFeatureExtractor.from_pretrained("AhmadHakami/alzheimer-image-classification-google-vit-base-patch16")
10
+ model = AutoModelForImageClassification.from_pretrained("AhmadHakami/alzheimer-image-classification-google-vit-base-patch16")
11
+
12
+ # Mapeamento de classe ID para rótulo
13
+ id2label = {
14
+ "0": "Mild_Demented",
15
+ "1": "Moderate_Demented",
16
+ "2": "Non_Demented",
17
+ "3": "Very_Mild_Demented"
18
+ }
19
+
20
+ # Função para classificar a imagem
21
+ def classify_image(input_image):
22
+ # Pré-processar a imagem
23
+ image = Image.fromarray(input_image)
24
+
25
+ # Aplicar transformações à imagem
26
+ transform = transforms.Compose([
27
+ transforms.Resize((224, 224)),
28
+ transforms.ToTensor(),
29
+ ])
30
+ input_image = transform(image).unsqueeze(0)
31
+
32
+ # Realizar a classificação de imagem
33
+ inputs = extractor(images=input_image, return_tensors="pt")
34
+ outputs = model(**inputs)
35
+
36
+ # Obter a classe prevista a partir da saída do modelo
37
+ predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
38
+ predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
39
+
40
+ # Obter a data e hora atual
41
+ current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
42
+
43
+ # Formatar a saída em HTML com rótulo da classe e data/hora
44
+ result_html = f"""
45
+ <h2>Resultado da Classificação</h2>
46
+ <p><strong>Rótulo da Classe:</strong> {predicted_class_label}</p>
47
+ <p><strong>Data e Hora:</strong> {current_time}</p>
48
+ """
49
+
50
+ return result_html
51
+
52
+ # Criar uma interface Gradio com a função de classificação de imagem
53
+ iface = gr.Interface(
54
+ fn=classify_image,
55
+ inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
56
+ outputs=gr.outputs.HTML(), # Saída formatada com HTML
57
+ title="Classificador de Imagem ViT para Demência",
58
+ description="Esta aplicação Gradio permite classificar imagens relacionadas à demência usando um modelo Vision Transformer (ViT).",
59
+ )
60
+
61
+ # Iniciar a aplicação Gradio
62
+ iface.launch()