Vijayendra's picture
Update README.md
96a411a verified
metadata
license: mit
datasets:
  - fancyzhx/ag_news
language:
  - en
metrics:
  - accuracy
base_model:
  - google-t5/t5-large
pipeline_tag: text-classification
tags:
  - ag
  - news
  - document
  - classification

This model is google-t5/t5-large fine-tuned on the AG News dataset for 2 epochs on the 120,000 training samples. It was evaluated on the test set, with the metrics below.

Test Loss: 0.1629

Accuracy: 0.9521

F1 Score: 0.9521

Precision: 0.9522

Recall: 0.9522

# Import necessary libraries
import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Prefer the GPU when CUDA is available; otherwise run on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define the model class (same structure as used during training)
class CustomT5Model(nn.Module):
    """T5 encoder followed by a linear classification head.

    The complete ``T5ForConditionalGeneration`` model is instantiated, but
    ``forward`` only runs its encoder. The encoder state at the first
    sequence position is passed through a linear layer that produces one
    logit per class.

    Args:
        model_name: Checkpoint name handed to
            ``T5ForConditionalGeneration.from_pretrained``.
        num_classes: Number of output logits (4 for AG News).
    """

    def __init__(self, model_name: str = "t5-large", num_classes: int = 4):
        super().__init__()
        # NOTE(review): the attribute names `t5` and `classifier` determine the
        # state-dict keys; presumably the published checkpoint was saved with
        # this exact layout, so they must not be renamed -- confirm.
        self.t5 = T5ForConditionalGeneration.from_pretrained(model_name)
        # Size the head from the backbone's own config instead of hard-coding
        # 1024, so the layer always matches the encoder's hidden width.
        self.classifier = nn.Linear(self.t5.config.d_model, num_classes)

    def forward(self, input_ids, attention_mask=None):
        """Compute class logits for a batch of tokenized inputs.

        Args:
            input_ids: Token ids forwarded to the T5 encoder.
            attention_mask: Optional mask forwarded to the T5 encoder.

        Returns:
            Logits produced by the classification head, one row per input
            sequence.
        """
        encoder_outputs = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # (batch_size, seq_len, hidden_dim)
        hidden_states = encoder_outputs.last_hidden_state
        # T5 has no [CLS] token: this is simply the encoder state of the first
        # input token, used here as the summary of the whole sequence.
        first_token_state = hidden_states[:, 0, :]
        return self.classifier(first_token_state)

# Build the classifier, then move it onto the selected device
model = CustomT5Model()
model = model.to(device)

# Fetch the fine-tuned weights published on the Hugging Face Hub and load them.
# NOTE(review): the checkpoint is a pickled .pth file downloaded from a URL;
# only load it from a source you trust.
model_path = (
    "https://huggingface.co/Vijayendra/T5-large-docClassification"
    "/resolve/main/best_model.pth"
)
model.load_state_dict(
    torch.hub.load_state_dict_from_url(model_path, map_location=device)
)
# Switch to evaluation mode for inference
model.eval()

# Tokenizer of the t5-large base checkpoint
tokenizer = T5Tokenizer.from_pretrained("t5-large")

# Inference function
def infer(model, tokenizer, text, max_length=99):
    """Classify a single news text into one of the four AG News categories.

    The model is switched to evaluation mode and the forward pass runs with
    gradient tracking disabled.

    Args:
        model: Module that returns class logits when called with
            ``input_ids`` and ``attention_mask`` keyword arguments.
        tokenizer: Callable that accepts a list of strings plus the usual
            ``max_length`` / ``truncation`` / ``padding`` / ``return_tensors``
            keywords and returns a mapping holding ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        text: The raw text to classify.
        max_length: Token length the input is truncated/padded to. The
            default of 99 is the value this script has always used
            (presumably the training-time length -- confirm before changing).

    Returns:
        One of ``"World"``, ``"Sports"``, ``"Business"`` or ``"Sci/Tech"``.
    """
    # AG News class index -> human-readable label
    labels = ("World", "Sports", "Business", "Sci/Tech")

    # Send the inputs to wherever the model's weights live, rather than
    # relying on a module-level `device` global that may not match the model.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cpu")

    model.eval()
    with torch.no_grad():
        # The "classify: " prefix is part of the model input; presumably it
        # mirrors the prompt used during fine-tuning -- confirm.
        inputs = tokenizer(
            [f"classify: {text}"],
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        logits = model(
            input_ids=inputs["input_ids"].to(target_device),
            attention_mask=inputs["attention_mask"].to(target_device),
        )
        predicted = torch.argmax(logits, dim=-1)

    return labels[predicted.item()]

# Example usage: classify one headline and print the predicted category
text = "NASA announces new mission to study asteroids"
result = infer(model=model, tokenizer=tokenizer, text=text)
print(f"Predicted category: {result}")