|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, AutoModel |
|
from PIL import Image |
|
from torchvision import transforms |
|
import json |
|
from torch import nn |
|
from typing import Literal |
|
|
|
|
|
class MultimodalClassifier(nn.Module):
    """Classify an (image, text) pair by fusing features from two pretrained encoders.

    The text branch takes the position-0 vector of the text encoder's
    ``last_hidden_state``. The image branch averages the image encoder's
    ``last_hidden_state`` over its last two dimensions. Each branch is
    linearly projected to ``projection_dim``. The two projections are then
    fused, passed through a Dropout/Linear/GELU/Dropout block, and mapped to
    ``num_classes`` logits. No output activation is applied here.

    Args:
        text_encoder_id_or_path: Hub id or local path passed to
            ``AutoModel.from_pretrained`` for the text encoder.
        image_encoder_id_or_path: Hub id or local path passed to
            ``AutoModel.from_pretrained`` (with ``trust_remote_code=True``)
            for the image encoder.
        projection_dim: Width of both projections and of the fusion layer.
        fusion_method: ``"concat"`` concatenates the two projections. Every
            other value, including ``"align"`` and ``"cosine_similarity"``,
            falls through to an element-wise product.
        proj_dropout: Dropout probability applied after each projection.
        fusion_dropout: Dropout probability applied before and after the
            fusion Linear + GELU.
        num_classes: Number of output logits.
    """

    def __init__(
        self,
        text_encoder_id_or_path: str,
        image_encoder_id_or_path: str,
        projection_dim: int,
        fusion_method: Literal["concat", "align", "cosine_similarity"] = "concat",
        proj_dropout: float = 0.1,
        fusion_dropout: float = 0.1,
        num_classes: int = 1,
    ) -> None:
        super().__init__()

        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.num_classes = num_classes

        # --- Text branch ---
        self.text_encoder = AutoModel.from_pretrained(text_encoder_id_or_path)
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        # --- Image branch ---
        self.image_encoder = AutoModel.from_pretrained(
            image_encoder_id_or_path, trust_remote_code=True
        )
        # Overwrite any `classifier` attribute with a parameter-free identity.
        # forward() below only reads `last_hidden_state` from this encoder.
        self.image_encoder.classifier = nn.Identity()
        # Size the projection from the encoder's config instead of a hard-coded
        # 512, so other backbones work. ResNet-34 still resolves to 512.
        self.image_projection = nn.Sequential(
            nn.Linear(self._image_feature_dim(self.image_encoder), self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        # --- Fusion head ---
        # Concatenation doubles the feature width; the element-wise product keeps it.
        fusion_input_dim = (
            self.projection_dim * 2 if fusion_method == "concat" else self.projection_dim
        )
        self.fusion_layer = nn.Sequential(
            nn.Dropout(fusion_dropout),
            nn.Linear(fusion_input_dim, self.projection_dim),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
        )

        self.classifier = nn.Linear(self.projection_dim, self.num_classes)

    @staticmethod
    def _image_feature_dim(image_encoder: object, default: int = 512) -> int:
        """Return the feature width produced by the image encoder.

        Reads ``image_encoder.config.hidden_sizes[-1]``, the last per-stage
        width exposed by Hugging Face ResNet configs. Falls back to
        ``default`` when the encoder has no such config entry. The default
        of 512 is the value this class previously hard-coded.
        """
        config = getattr(image_encoder, "config", None)
        hidden_sizes = getattr(config, "hidden_sizes", None)
        if hidden_sizes:
            return int(hidden_sizes[-1])
        return default

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return raw classification logits for a batch of (image, text) pairs.

        Args:
            pixel_values: Preprocessed images for the image encoder
                (presumably ``(batch, 3, H, W)``; TODO confirm).
            input_ids: Token ids for the text encoder.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            Un-normalized logits whose last dimension is ``num_classes``.
        """
        # Text: hidden state at sequence position 0 (the [CLS] slot for
        # BERT-style tokenizers), projected to `projection_dim`.
        text_hidden = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state
        text_features = self.text_projection(text_hidden[:, 0, :])

        # Image: average the last hidden state over its final two axes
        # (presumably spatial H and W, i.e. global average pooling), then project.
        image_hidden = self.image_encoder(pixel_values=pixel_values).last_hidden_state
        image_features = self.image_projection(image_hidden.mean(dim=[-2, -1]))

        if self.fusion_method == "concat":
            fused_features = torch.cat([text_features, image_features], dim=-1)
        else:
            # NOTE(review): "align" and "cosine_similarity" both use this plain
            # element-wise product; no normalization is applied here.
            fused_features = text_features * image_features

        fused_features = self.fusion_layer(fused_features)
        return self.classifier(fused_features)
|
|
|
|
|
def load_model(
    config_path: str = "config.json",
    weights_path: str = "model_weights.pth",
) -> MultimodalClassifier:
    """Build the classifier from a JSON config and load its weights onto the CPU.

    Args:
        config_path: JSON file that must provide ``text_encoder_id_or_path``,
            ``projection_dim``, ``fusion_method``, ``proj_dropout``,
            ``fusion_dropout`` and ``num_classes``.
        weights_path: Checkpoint file. Its contents are passed straight to
            ``load_state_dict``, so it is expected to be a state dict.

    Returns:
        The model with the checkpoint applied. It is not switched to eval
        mode here; callers should call ``.eval()`` before inference.

    Raises:
        KeyError: If the config lacks one of the required keys.

    Warns:
        UserWarning: If the checkpoint's keys do not exactly match the
            model's parameters. Unmatched parameters keep the values they had
            before loading.
    """
    import warnings

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = MultimodalClassifier(
        text_encoder_id_or_path=config["text_encoder_id_or_path"],
        image_encoder_id_or_path="microsoft/resnet-34",
        projection_dim=config["projection_dim"],
        fusion_method=config["fusion_method"],
        proj_dropout=config["proj_dropout"],
        fusion_dropout=config["fusion_dropout"],
        num_classes=config["num_classes"],
    )

    # NOTE(security): torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=torch.device("cpu"))

    # strict=False tolerates key mismatches. Doing that silently can leave
    # layers (e.g. the randomly initialized heads) untrained without any sign,
    # so report anything that did not line up.
    incompatible = model.load_state_dict(checkpoint, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {weights_path!r} did not match the model exactly: "
            f"{len(incompatible.missing_keys)} missing key(s) "
            f"{incompatible.missing_keys[:5]}, "
            f"{len(incompatible.unexpected_keys)} unexpected key(s) "
            f"{incompatible.unexpected_keys[:5]}. Unmatched parameters keep "
            "their pre-load values.",
            stacklevel=2,
        )

    return model
|
|
|
|
|
# Build the model once at import time and switch it to inference mode.
# nn.Module.eval() returns the module itself, so the two calls chain.
model = load_model().eval()

# NOTE(review): the tokenizer is pinned to bert-base-uncased, while the text
# encoder comes from config.json's "text_encoder_id_or_path". Confirm that
# both use the same vocabulary.
text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Image preprocessing: resize to a fixed 224x224, convert to a float tensor,
# then normalize each channel with the commonly used ImageNet mean/std values.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
def predict(image: Image.Image, text: str) -> str:
    """Classify an uploaded image plus accompanying text as fake or real news.

    Args:
        image: PIL image from the Gradio image input. Any mode is accepted;
            it is converted to RGB before preprocessing.
        text: News text. ``None`` is treated as an empty string.

    Returns:
        ``"Fake News"`` when the sigmoid of the model's single logit rounds
        to 1, otherwise ``"Real News"``.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    if image is None:
        # Gradio passes None for an empty image input. Report it in the UI
        # instead of failing inside the transform.
        raise gr.Error("Please upload an image.")

    # Every input is padded or truncated to exactly 512 tokens.
    # NOTE(review): presumably mirrors training-time tokenization; confirm before changing.
    text_inputs = text_tokenizer(
        text or "",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )

    # The 3-value Normalize (and the image encoder) need exactly three
    # channels. Without this, RGBA, grayscale and palette uploads yield
    # 4- or 1-channel tensors and crash.
    image_input = image_transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        logits = model(
            pixel_values=image_input,
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"],
        )

    # .item() requires a single-element output, i.e. num_classes == 1.
    # round() maps probabilities above 0.5 to 1; exactly 0.5 rounds to 0.
    predicted_class = torch.sigmoid(logits).round().item()

    return "Fake News" if predicted_class == 1 else "Real News"
|
|
|
|
|
def _build_interface() -> gr.Interface:
    """Assemble the Gradio UI: an image upload and a text box feed `predict`.

    The string that `predict` returns is displayed in a Label component.
    """
    image_box = gr.Image(type="pil", label="Upload Related Image")
    text_box = gr.Textbox(
        lines=2,
        placeholder="Enter news text for classification...",
        label="Input Text",
    )
    result_label = gr.Label(label="Prediction")

    return gr.Interface(
        fn=predict,
        inputs=[image_box, text_box],
        outputs=result_label,
        title="Fake News Detector",
        description="Upload an image and provide text to classify the news as 'Fake' or 'Real'.",
    )


interface = _build_interface()

# Start the Gradio app as soon as this module is executed.
interface.launch()
|
|