# AI_detector / model_architecture.py
# Uploaded by RingoDingo
# Commit message: "Upload model_architecture.py"
# Revision: 6588ad6
import torch
import numpy
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import ViTModel, ViTFeatureExtractor
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn
# Image processor for the same Hugging Face checkpoint the ViT backbone below is
# loaded from; the demo at the bottom of this file passes it to predict_image.
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers releases
# in favour of ViTImageProcessor -- consider switching after confirming the
# installed version produces identical pixel_values.
_FEATURE_EXTRACTOR_CHECKPOINT = 'google/vit-large-patch32-384'
feature_extractor = ViTFeatureExtractor.from_pretrained(_FEATURE_EXTRACTOR_CHECKPOINT)
class ViTForImageClassification(nn.Module):
    """Image classifier: a pretrained ``ViTModel`` backbone plus a linear head.

    The head takes the backbone's last hidden state at token position 0
    (``last_hidden_state[:, 0]``), applies dropout, and projects it to
    ``num_labels`` logits.

    Note: this is a locally defined module, not
    ``transformers.ViTForImageClassification``; its state-dict entries live
    under the ``vit.`` and ``classifier.`` prefixes.
    """

    def __init__(self, num_labels=2,
                 pretrained_model_name='google/vit-large-patch32-384',
                 dropout_prob=0.1):
        """Build the backbone and classification head.

        Args:
            num_labels: number of output classes.
            pretrained_model_name: Hugging Face checkpoint the ViT backbone is
                loaded from. The default is the checkpoint this module was
                written for; saved weights only fit a backbone of the same shape.
            dropout_prob: dropout probability applied to the token-0 feature.
        """
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout_prob)
        # Head width follows the backbone's hidden size so other checkpoints work.
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        """Classify a batch of images.

        Args:
            pixel_values: image batch forwarded unchanged to the ViT backbone
                (presumably (batch, 3, H, W) float -- TODO confirm against the
                image processor used by callers).
            labels: optional integer class targets; flattened with ``view(-1)``.

        Returns:
            ``SequenceClassifierOutput(loss=..., logits=...)`` when ``labels`` is
            given; otherwise the raw logits tensor. Callers such as
            ``predict_image`` rely on getting bare logits when no labels are passed.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Token 0 is the leading token of the sequence; its hidden state is the
        # image-level feature fed to the head.
        features = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(features)

        if labels is None:
            return logits

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
def preprocess_image(image, desired_size=384):
im = image
# Resize and pad the image
old_size = im.size
ratio = float(desired_size) / max(old_size)
new_size = tuple([int(x*ratio) for x in old_size])
im = im.resize(new_size)
# Create a new image and paste the resized on it
new_im = Image.new("RGB", (desired_size, desired_size), "white")
new_im.paste(im, ((desired_size-new_size[0])//2, (desired_size-new_size[1])//2))
return new_im
def predict_image(image, model, feature_extractor):
    """Return softmax class probabilities for a single image.

    Args:
        image: a PIL image (in this file, the output of ``preprocess_image``).
        model: a classifier that returns raw logits when called without labels,
            e.g. ``ViTForImageClassification``. It is switched to eval mode.
        feature_extractor: Hugging Face image processor whose ``'pixel_values'``
            output is stacked into the model input.

    Returns:
        A numpy array of probabilities from softmax over dim 1, with one row per
        image in ``pixel_values`` (one row for a single image).

    The input is placed on the device of the model's parameters (CPU for a
    parameter-less model), so both CPU and CUDA models are supported.
    """
    # Inference behaviour, e.g. dropout disabled.
    model.eval()

    input_tensor = ToTensor()(image)
    # NOTE(review): ToTensor() already scales pixels to [0, 1]; depending on the
    # transformers version the processor may rescale again. Kept unchanged
    # because the trained weights presumably expect exactly this pipeline.
    pixel_values = feature_extractor(input_tensor)['pixel_values']
    input_tensor = torch.tensor(numpy.array(pixel_values))

    # Follow the model's device instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    input_tensor = input_tensor.to(device)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)
    return probabilities.cpu().numpy()
# Demo: load the fine-tuned weights, classify ./test.png and print both scores.
model = ViTForImageClassification(num_labels=2)
# map_location="cpu" avoids first materialising the checkpoint on whatever device
# it was saved from. load_state_dict copies the tensors into the model, which is
# moved to CUDA right after.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load("./AID96k_E15_384.pth", map_location="cpu"))
model.cuda()
model.eval()

# Close the source file once the letterboxed copy has been produced.
with Image.open("test.png") as source_img:
    img = preprocess_image(source_img)

probs = predict_image(img, model, feature_extractor)
# Presumably index 0 = AI-generated and index 1 = human, per the original labels.
print(f"AI: {probs[0][0]}")
print(f"Human: {probs[0][1]}")