|
import torch |
|
import numpy |
|
from PIL import Image |
|
from torchvision.transforms import ToTensor |
|
from transformers import ViTModel, ViTFeatureExtractor |
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
import torch.nn as nn |
|
|
|
# Module-level preprocessor shared by the inference code below. It is created
# at import time from the same checkpoint id the ViT backbone uses.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor -- confirm against the pinned version before
# switching, since the preprocessing must match what the weights were trained on.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-large-patch32-384')
|
class ViTForImageClassification(nn.Module):
    """ViT backbone followed by dropout and a linear classification head.

    The head reads position 0 of the backbone's ``last_hidden_state`` (the
    [CLS] token for ``ViTModel``) and maps it to ``num_labels`` logits.

    Args:
        num_labels: Number of output classes.
        pretrained_name: Checkpoint id or path handed to
            ``ViTModel.from_pretrained``. Defaults to the checkpoint that was
            previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, num_labels=2,
                 pretrained_name='google/vit-large-patch32-384'):
        super().__init__()
        # Attribute names (vit / dropout / classifier) are part of the
        # state-dict keys of saved checkpoints; do not rename them.
        self.vit = ViTModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        """Run the backbone and the classification head.

        Args:
            pixel_values: Image batch passed to the backbone as
                ``pixel_values``.
            labels: Optional class indices. When given, a cross-entropy loss
                is computed against the logits.

        Returns:
            The raw logits tensor when ``labels`` is ``None``; otherwise a
            ``SequenceClassifierOutput`` carrying ``loss`` and ``logits``.
        """
        outputs = self.vit(pixel_values=pixel_values)
        cls_state = outputs.last_hidden_state[:, 0]
        logits = self.classifier(self.dropout(cls_state))

        # Inference path: callers expect a plain tensor, not an output object.
        if labels is None:
            return logits

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
|
|
|
def preprocess_image(image, desired_size=384):
    """Scale an image to fit a white square canvas, keeping its aspect ratio.

    The image is resized so that its longer side equals ``desired_size`` and
    is then pasted, centred, onto a new white RGB canvas.

    Args:
        image: A PIL image (anything exposing ``size``, ``resize`` and usable
            as a ``paste`` source).
        desired_size: Side length, in pixels, of the square output.

    Returns:
        A new RGB PIL image of size ``(desired_size, desired_size)``.

    Raises:
        ValueError: If the input image has a zero width or height.
    """
    old_size = image.size
    if min(old_size) <= 0:
        raise ValueError(f"cannot preprocess an empty image of size {old_size}")

    ratio = float(desired_size) / max(old_size)
    # Truncation can drive the short side of a very elongated image to 0,
    # which resize() rejects; keep at least one pixel per dimension.
    new_size = tuple(max(1, int(side * ratio)) for side in old_size)
    resized = image.resize(new_size)

    canvas = Image.new("RGB", (desired_size, desired_size), "white")
    offset = (
        (desired_size - new_size[0]) // 2,
        (desired_size - new_size[1]) // 2,
    )
    canvas.paste(resized, offset)
    return canvas
|
|
|
def predict_image(image, model, feature_extractor):
    """Return softmax class probabilities for a single image.

    Args:
        image: Input accepted by torchvision's ``ToTensor`` (e.g. a PIL image).
        model: Module that is called with the pixel tensor and returns logits
            with the class scores on dimension 1. It is switched to eval mode.
        feature_extractor: Callable returning a mapping with a
            ``'pixel_values'`` entry for the given tensor.

    Returns:
        A NumPy array of probabilities (softmax over dimension 1).
    """
    model.eval()

    # NOTE(review): ToTensor already scales pixel values to [0, 1]; depending
    # on the transformers version the feature extractor may rescale again.
    # Left unchanged so inference matches the existing pipeline -- confirm
    # against how the checkpoint was trained.
    pixels = ToTensor()(image)
    features = feature_extractor(pixels)['pixel_values']
    input_tensor = torch.tensor(numpy.array(features))

    # Follow the model's device instead of assuming CUDA, so the function
    # also works for CPU-resident models and on hosts without a GPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = input_tensor.to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)

    return probabilities.cpu().numpy()
|
|
|
|
|
# Use the GPU when one is available; otherwise load and keep the weights on
# the CPU so the checkpoint can still be deserialized without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ViTForImageClassification(num_labels=2)
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load("./AID96k_E15_384.pth", map_location=device))
model.to(device)
model.eval()  # predict_image() also calls eval(); kept for clarity.

# Close the file handle as soon as the letterboxed copy has been produced.
with Image.open("test.png") as raw_img:
    img = preprocess_image(raw_img)

probs = predict_image(img, model, feature_extractor)

# Index 0 is reported as "AI", index 1 as "Human".
print(f"AI: {probs[0][0]}")
print(f"Human: {probs[0][1]}")
|
|