---
license: mit
datasets:
- fancyzhx/ag_news
language:
- en
metrics:
- accuracy
base_model:
- google-t5/t5-large
pipeline_tag: text-classification
tags:
- ag
- news
- document
- classification
---
|
This model fine-tunes google-t5/t5-large on the AG News dataset for 2 epochs over the full 120,000-sample training set. Evaluation on the test set gives the following metrics:
|
|
|
- Test Loss: 0.1629
- Accuracy: 0.9521
- F1 Score: 0.9521
- Precision: 0.9522
- Recall: 0.9522
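
The snippet below reconstructs the classification head used during fine-tuning (a linear layer on top of the T5 encoder), downloads the checkpoint, and runs single-example inference: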
|
|
|
|
|
```python
# Import necessary libraries
import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model class (same structure as used during training)
class CustomT5Model(nn.Module):
    def __init__(self):
        super(CustomT5Model, self).__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained("t5-large")
        self.classifier = nn.Linear(1024, 4)  # t5-large hidden size -> 4 AG News classes

    def forward(self, input_ids, attention_mask=None):
        # Run only the T5 encoder; the decoder is not needed for classification
        encoder_outputs = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden_states = encoder_outputs.last_hidden_state  # (batch_size, seq_len, hidden_dim)
        # Pool with the first token's hidden state (T5 has no dedicated [CLS] token)
        logits = self.classifier(hidden_states[:, 0, :])
        return logits

# Initialize the model
model = CustomT5Model().to(device)

# Load the saved model weights from Hugging Face
model_path = "https://huggingface.co/Vijayendra/T5-large-docClassification/resolve/main/best_model.pth"
model.load_state_dict(torch.hub.load_state_dict_from_url(model_path, map_location=device))
model.eval()

# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-large")

# Inference function
def infer(model, tokenizer, text):
    model.eval()
    with torch.no_grad():
        # Preprocess the input text with the "classify:" prompt used during training
        inputs = tokenizer(
            [f"classify: {text}"],
            max_length=99,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Get model predictions
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(logits, dim=-1)

        # Map class index to label
        label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
        return label_map[preds.item()]

# Example usage
text = "NASA announces new mission to study asteroids"
result = infer(model, tokenizer, text)
print(f"Predicted category: {result}")
```