# Hugging Face Spaces page banner captured with this file (not code): "Spaces: Runtime error".
import os
from typing import Optional

# Disable the Weights & Biases integration.
os.environ["WANDB_DISABLED"] = "true"

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
# NOTE(review): `load_metric` appears to have been removed from recent
# `datasets` releases; confirm the pinned version still provides it.
from datasets import load_dataset, load_metric
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_model
from transformers import CLIPModel, AutoModel
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    logging,
)
class VisionTextDualEncoderModel(nn.Module):
    """Late-fusion text + image classifier.

    A text encoder and the vision tower of a CLIP model each produce a pooled
    feature vector; the two vectors are concatenated and projected to
    ``num_classes`` logits by a single linear layer.

    Attribute names (``text_encoder``, ``vision_encoder``, ``fc``) are part of
    the checkpoint layout and must not be renamed.
    """

    def __init__(self, num_classes):
        """Build both encoders and the classification head.

        Args:
            num_classes: Number of output classes (size of the logits vector).
        """
        super().__init__()
        # Text branch: XLM-RoBERTa backbone loaded without a task head.
        self.text_encoder = AutoModel.from_pretrained(
            "cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual"
        )
        # Vision branch: full CLIP model; only its `vision_model` tower is
        # used in `forward`.
        self.vision_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        vision_output_dim = self.vision_encoder.config.vision_config.hidden_size
        # Fusion head over the concatenated [text ; vision] pooled features.
        self.fc = nn.Linear(
            self.text_encoder.config.hidden_size + vision_output_dim, num_classes
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """Classify a batch of (text, image) pairs.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids: Forwarded
                to the text encoder.
            pixel_values: Forwarded to the CLIP vision tower.
            output_attentions, output_hidden_states: Forwarded to the vision
                tower; the extra outputs are not returned.
            return_loss, return_dict, labels: Accepted for signature
                compatibility only; they do not influence the result and no
                loss is computed.

        Returns:
            ``{"logits": tensor}`` where the tensor is the output of the
            linear head for each item in the batch.
        """
        # Both encoders are always asked for structured outputs: the pooled
        # features are read via `.pooler_output`, which does not exist on the
        # tuple returned when `return_dict=False`.
        text_features = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            return_dict=True,
        ).pooler_output

        vision_outputs = self.vision_encoder.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Concatenate along the feature axis and classify.
        combined_features = torch.cat(
            (text_features, vision_outputs.pooler_output), dim=1
        )
        logits = self.fc(combined_features)
        return {"logits": logits}
# Class-index <-> label mapping used to interpret the 3-way classifier output.
id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}

tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-base-sentiment-multilingual")
model = VisionTextDualEncoderModel(num_classes=3)
# The wrapper has no `vision_text_model` attribute; the CLIP config (which
# carries `vision_config`) lives on the vision encoder.
config = model.vision_encoder.config

# https://huggingface.co/FFZG-cleopatra/M2SA/blob/main/model.safetensors
sf_filename = hf_hub_download("FFZG-cleopatra/M2SA", filename="model.safetensors")
# Load from the path returned by the download, which points into the local
# cache, rather than from a bare filename relative to the working directory.
load_model(model, sf_filename)
# Keep the fine-tuned multimodal model loaded above: replacing it with an
# AutoModelForSequenceClassification would discard these weights and would not
# accept the `pixel_values` input that the prediction function supplies.
model.eval()
# Lazily created CLIP image processor shared by all prediction calls.
_IMAGE_PROCESSOR = None


def _get_image_processor():
    """Return the shared CLIP image processor, loading it on first use."""
    global _IMAGE_PROCESSOR
    if _IMAGE_PROCESSOR is None:
        from transformers import CLIPImageProcessor

        # Same checkpoint as the vision encoder, so resizing and
        # normalisation follow that checkpoint's preprocessing config.
        _IMAGE_PROCESSOR = CLIPImageProcessor.from_pretrained(
            "openai/clip-vit-base-patch32"
        )
    return _IMAGE_PROCESSOR


def predict_sentiment(text, image):
    """Predict the sentiment label for a text/image pair.

    Args:
        text: Input text; ``None`` is treated as an empty string.
        image: Input image as delivered by the UI (NumPy array or PIL image).

    Returns:
        The predicted label string looked up in ``id2label``
        (presumably matching the class order used at training time -
        TODO confirm).

    Raises:
        ValueError: If no image is supplied; the model needs both modalities.
    """
    if image is None:
        raise ValueError("An image is required to predict the sentiment.")

    # Tensors (not Python lists) are required by the model.
    text_inputs = tokenizer(
        text or "",
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    pixel_values = _get_image_processor()(images=image, return_tensors="pt")[
        "pixel_values"
    ]

    model_input = {
        "input_ids": text_inputs.input_ids,
        "pixel_values": pixel_values,
        "attention_mask": text_inputs.attention_mask,
    }

    with torch.no_grad():
        # Unpack as keyword arguments; passing the dict positionally would
        # bind it to `input_ids`.
        prediction = model(**model_input)

    # Single-item batch: take the arg-max class of the first row.
    label_index = int(prediction["logits"].argmax(dim=-1)[0])
    return id2label[label_index]
# Web UI: one text box and one image upload in, predicted label text out.
interface = gr.Interface(
    fn=predict_sentiment,
    # Top-level components replace the legacy `gr.inputs.*` namespace; image
    # resizing is left to the model's own preprocessing.
    inputs=[gr.Textbox(), gr.Image()],
    outputs=["text"],
    title='Multilingual-Multimodal-Sentiment-Analysis',
    # With two inputs each example must be a list with one value per input;
    # the examples only pre-fill the text, so the image slot is None.
    examples=[["I love tea", None], ["I hate coffee", None]],
    # Examples have no image, so they cannot be pre-computed at start-up.
    cache_examples=False,
    description='Get the positive/neutral/negative sentiment for the given input.',
)
interface.launch(inline=False)