RingoDingo committed
Commit · 6588ad6
1 Parent(s): 6bd4036
Upload model_architecture.py
model_architecture.py +76 -0
model_architecture.py ADDED
@@ -0,0 +1,76 @@
+import torch
+import numpy
+from PIL import Image
+from torchvision.transforms import ToTensor
+from transformers import ViTModel, ViTFeatureExtractor
+from transformers.modeling_outputs import SequenceClassifierOutput
+import torch.nn as nn
+
+feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-large-patch32-384')
+class ViTForImageClassification(nn.Module):
+    def __init__(self, num_labels=2):
+        super(ViTForImageClassification, self).__init__()
+        self.vit = ViTModel.from_pretrained('google/vit-large-patch32-384')
+        self.dropout = nn.Dropout(0.1)
+        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
+        self.num_labels = num_labels
+
+    def forward(self, pixel_values, labels=None):
+        outputs = self.vit(pixel_values=pixel_values)
+        output = self.dropout(outputs.last_hidden_state[:,0])
+        logits = self.classifier(output)
+
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            return SequenceClassifierOutput(
+                loss=loss,
+                logits=logits,
+            )
+        else:
+            return logits
+
+def preprocess_image(image, desired_size=384):
+    im = image
+
+    # Resize and pad the image
+    old_size = im.size
+    ratio = float(desired_size) / max(old_size)
+    new_size = tuple([int(x*ratio) for x in old_size])
+    im = im.resize(new_size)
+
+    # Create a new image and paste the resized on it
+    new_im = Image.new("RGB", (desired_size, desired_size), "white")
+    new_im.paste(im, ((desired_size-new_size[0])//2, (desired_size-new_size[1])//2))
+    return new_im
+
+def predict_image(image, model, feature_extractor):
+    # Ensure model is in eval mode
+    model.eval()
+
+    # Convert image to tensor
+    transform = ToTensor()
+    input_tensor = transform(image)
+    input_tensor = torch.tensor(numpy.array(feature_extractor(input_tensor)['pixel_values']))
+
+    # Move tensors to the right device
+    input_tensor = input_tensor.cuda()
+
+    # Forward pass of the image through the model
+    output = model(input_tensor)
+
+    # Convert model output to probabilities using softmax
+    probabilities = torch.nn.functional.softmax(output, dim=1)
+
+    return probabilities.cpu().detach().numpy()
+
+
+model = ViTForImageClassification(num_labels=2)
+model.load_state_dict(torch.load("./AID96k_E15_384.pth"))
+model.cuda()
+model.eval()
+img = Image.open("test.png")
+img = preprocess_image(img)
+probs = predict_image(img, model, feature_extractor)
+print(f"AI: {probs[0][0]}")
+print(f"Human: {probs[0][1]}")
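One maintenance note: newer releases of transformers deprecate ViTFeatureExtractor in favor of ViTImageProcessor. Assuming an installed version recent enough to ship it (roughly v4.26+; the commit pins no dependency versions, so this is an assumption), the swap should be drop-in:

from transformers import ViTImageProcessor

# Assumed drop-in replacement for the deprecated feature extractor;
# same checkpoint name as in the committed file.
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-large-patch32-384')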
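The bottom of the script hard-codes CUDA (model.cuda(), input_tensor.cuda()) and a local checkpoint path. Below is a minimal sketch of the same inference flow that falls back to CPU when no GPU is present; the map_location remap and the torch.no_grad() guard are assumptions, everything else mirrors the committed code:

import numpy
import torch
from PIL import Image
from torchvision.transforms import ToTensor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ViTForImageClassification(num_labels=2)
# map_location lets a GPU-saved checkpoint load on CPU-only machines.
model.load_state_dict(torch.load("./AID96k_E15_384.pth", map_location=device))
model.to(device)
model.eval()

img = preprocess_image(Image.open("test.png"))

# Same preprocessing chain as predict_image(), minus the hard-coded .cuda()
pixel_values = torch.tensor(
    numpy.array(feature_extractor(ToTensor()(img))['pixel_values'])
).to(device)

with torch.no_grad():
    probs = torch.nn.functional.softmax(model(pixel_values), dim=1)

print(f"AI: {probs[0][0].item()}")
print(f"Human: {probs[0][1].item()}")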
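Nothing in the commit actually trains the model, but forward() has a labels branch that returns a SequenceClassifierOutput carrying a cross-entropy loss, which is what a training loop would consume. A hedged sketch of a single step follows; the optimizer choice, learning rate, and random stand-in batch are assumptions, and the label order (0 = AI, 1 = Human) is inferred from the prints above:

import torch

device = next(model.parameters()).device
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # assumed settings

model.train()
pixel_values = torch.randn(2, 3, 384, 384).to(device)  # random stand-in batch
labels = torch.tensor([0, 1]).to(device)  # 0 = AI, 1 = Human (inferred)

out = model(pixel_values, labels=labels)  # SequenceClassifierOutput(loss, logits)
out.loss.backward()
optimizer.step()
optimizer.zero_grad()
model.eval()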