RingoDingo committed on
Commit
6588ad6
·
1 Parent(s): 6bd4036

Upload model_architecture.py

Browse files
Files changed (1) hide show
  1. model_architecture.py +76 -0
model_architecture.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy
3
+ from PIL import Image
4
+ from torchvision.transforms import ToTensor
5
+ from transformers import ViTModel, ViTFeatureExtractor
6
+ from transformers.modeling_outputs import SequenceClassifierOutput
7
+ import torch.nn as nn
8
+
9
# Shared preprocessing object, created once when the module is loaded; the
# script at the bottom of the file hands it to predict_image().
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-large-patch32-384"
)
10
class ViTForImageClassification(nn.Module):
    """ViT backbone followed by dropout and a linear classification head.

    The backbone is loaded with ``ViTModel.from_pretrained``. The head maps
    the backbone's hidden state at sequence index 0 to ``num_labels`` logits.

    Args:
        num_labels: Number of output classes (width of the logits).
        pretrained_name: Checkpoint identifier handed to
            ``ViTModel.from_pretrained``. The default is the checkpoint this
            class previously hard-coded.
        dropout_prob: Dropout probability applied before the classifier.
            The default is the previously hard-coded value.
    """

    def __init__(self, num_labels=2, *,
                 pretrained_name='google/vit-large-patch32-384',
                 dropout_prob=0.1):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            pixel_values: Tensor forwarded unchanged to the backbone as
                ``pixel_values``.
            labels: Optional class-index tensor; flattened before the loss.

        Returns:
            The raw logits tensor when ``labels`` is ``None``; otherwise a
            ``SequenceClassifierOutput`` carrying ``loss`` and ``logits``.
            The two return types differ on purpose: callers that run
            inference apply softmax directly to the returned tensor.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Index 0 along the sequence axis — presumably the [CLS] token of the
        # ViT backbone; confirm against the transformers ViTModel docs.
        first_token = outputs.last_hidden_state[:, 0]
        logits = self.classifier(self.dropout(first_token))

        if labels is None:
            return logits

        # reshape() behaves like view() for contiguous tensors but does not
        # raise when the caller passes a non-contiguous one.
        loss = nn.CrossEntropyLoss()(
            logits.reshape(-1, self.num_labels),
            labels.reshape(-1),
        )
        return SequenceClassifierOutput(loss=loss, logits=logits)
32
+
33
def preprocess_image(image, desired_size=384):
    """Scale *image* to fit a white square canvas and centre it there.

    The image is resized with its aspect ratio kept so that its longer side
    becomes ``desired_size`` pixels, then pasted in the middle of a white
    ``desired_size`` x ``desired_size`` RGB canvas.

    Args:
        image: PIL image to letterbox; it is not modified.
        desired_size: Side length of the square output, in pixels.

    Returns:
        A new RGB PIL image of size ``(desired_size, desired_size)``.
    """
    old_size = image.size
    longest = max(old_size)

    # Exact integer scaling: the float form int(side * ratio) can land one
    # pixel short on the long side through rounding. Clamp to 1 px because a
    # very elongated input would otherwise truncate its short side to 0,
    # which Image.resize rejects.
    new_size = tuple(
        max(1, side * desired_size // longest) for side in old_size
    )
    resized = image.resize(new_size)

    # Centre the resized image on a white background.
    canvas = Image.new("RGB", (desired_size, desired_size), "white")
    offset = (
        (desired_size - new_size[0]) // 2,
        (desired_size - new_size[1]) // 2,
    )
    canvas.paste(resized, offset)
    return canvas
46
+
47
def predict_image(image, model, feature_extractor):
    """Run a single image through *model* and return softmax probabilities.

    Args:
        image: Input accepted by torchvision's ``ToTensor`` (the script in
            this file passes the PIL image produced by ``preprocess_image``).
        model: Module that returns a logits tensor when called with the
            pixel tensor; softmax is taken over dimension 1. It is switched
            to eval mode as a side effect.
        feature_extractor: Callable whose result exposes ``'pixel_values'``.

    Returns:
        A NumPy array holding the softmax of the model output.
    """
    # Ensure dropout and similar layers are disabled.
    model.eval()

    # Convert the image to a tensor, then let the extractor build the
    # model input.
    pixels = ToTensor()(image)
    features = feature_extractor(pixels)
    input_tensor = torch.tensor(numpy.array(features['pixel_values']))

    # Follow the model's own device instead of assuming CUDA, so a model
    # kept on the CPU (or a host without a GPU) does not crash here.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)
    elif torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    return probabilities.cpu().numpy()
66
+
67
+
68
# Use the GPU when one is present, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ViTForImageClassification(num_labels=2)
# map_location lets the checkpoint load even when it was saved from a device
# this host does not have.
# NOTE(review): torch.load unpickles the file — only load checkpoints that
# come from a trusted source.
state_dict = torch.load("./AID96k_E15_384.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# The context manager closes the file handle once the padded copy exists.
with Image.open("test.png") as source_image:
    img = preprocess_image(source_image)

probs = predict_image(img, model, feature_extractor)
# Index 0 is reported as the "AI" class and index 1 as the "Human" class.
print(f"AI: {probs[0][0]}")
print(f"Human: {probs[0][1]}")