fawadkhan commited on
Commit
13e0bbf
1 Parent(s): 7bd25ac

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +57 -6
README.md CHANGED
@@ -29,13 +29,64 @@ You can use this model directly with a pipeline for image classification:
29
  \`\`\`python
30
 
31
 
32
- from transformers import pipeline
33
- classifier = pipeline('image-classification', model='fawadkhan/ViT_FineTuned_on_ImagesOASIS')
 
 
 
34
 
35
- # To classify an image:
36
- image_path = 'path_to_your_image.jpg'
37
- prediction = classifier(image_path)
38
- print(prediction)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  \`\`\`
 
29
  \`\`\`python
30
 
31
 
32
+ import torch
33
+ from transformers import ViTForImageClassification
34
+ from PIL import Image
35
+ import numpy as np
36
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
37
 
38
+
39
+ id2label = {
40
+ 0: "Non-Demented",
41
+ 1: "Very Mild Demented",
42
+ 2: "Mild Demented",
43
+ 3: "Demented"
44
+ }
45
+
46
+ import torch
47
+ from transformers import ViTForImageClassification
48
+ from PIL import Image
49
+ import numpy as np
50
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
51
+ import matplotlib.pyplot as plt
52
+
53
+ # Set the device
54
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+
56
+ # Load the model
57
+ model = ViTForImageClassification.from_pretrained('fawadkhan/ViT_FineTuned_on_ImagesOASIS')
58
+ model.to(device)
59
+ model.eval()
60
+
61
+ # Define the image path
62
+ image_path = 'your image path.jpg'
63
+ image = Image.open(image_path).convert("RGB")
64
+
65
+ # Define the transformations
66
+ transform = Compose([
67
+ Resize((224, 224)), # or the original input size of your model
68
+ ToTensor(),
69
+ Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Standard normalization for ImageNet
70
+ ])
71
+
72
+ # Preprocess the image
73
+ input_tensor = transform(image).unsqueeze(0) # Create a mini-batch as expected by the model
74
+ input_tensor = input_tensor.to(device)
75
+
76
+ # Predict
77
+ with torch.no_grad():
78
+ outputs = model(input_tensor)
79
+ _, predicted = torch.max(outputs.logits, 1)
80
+
81
+ # Retrieve the class name
82
+ predicted_class = id2label[predicted[0].item()]
83
+ print("Predicted class:", predicted_class)
84
+
85
+ # Plot the image and the prediction
86
+ plt.imshow(image)
87
+ plt.title(f'Predicted class: {predicted_class}')
88
+ plt.axis('off') # Turn off axis numbers and ticks
89
+ plt.show()
90
 
91
 
92
  \`\`\`