---
license: mit
datasets:
- fancyzhx/ag_news
language:
- en
metrics:
- accuracy
base_model:
- google-t5/t5-large
pipeline_tag: text-classification
tags:
- ag
- news
- document
- classification
---

This model is fine-tuned on the AG News dataset for 2 epochs using 120,000 training samples and evaluated on the test set with the following metrics:

- Test Loss: 0.1629
- Accuracy: 0.9521
- F1 Score: 0.9521
- Precision: 0.9522
- Recall: 0.9522

```python
# Import necessary libraries
import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the model class (same structure as used during training):
# a T5 encoder with a linear classification head on top.
class CustomT5Model(nn.Module):
    def __init__(self):
        super(CustomT5Model, self).__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained("t5-large")
        self.classifier = nn.Linear(1024, 4)  # 4 classes for AG News

    def forward(self, input_ids, attention_mask=None):
        # Run only the T5 encoder; the decoder is not needed for classification.
        encoder_outputs = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = encoder_outputs.last_hidden_state  # (batch_size, seq_len, hidden_dim)
        # NOTE: T5 has no [CLS] token. This pools the hidden state of the
        # first input token (the start of the "classify:" prefix), which the
        # classifier head was trained on.
        logits = self.classifier(hidden_states[:, 0, :])
        return logits


# Initialize the model
model = CustomT5Model().to(device)

# Load the saved model weights from Hugging Face
model_path = "https://huggingface.co/Vijayendra/T5-large-docClassification/resolve/main/best_model.pth"
model.load_state_dict(torch.hub.load_state_dict_from_url(model_path, map_location=device))
model.eval()

# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-large")


# Inference function: returns the predicted AG News category for `text`.
def infer(model, tokenizer, text):
    model.eval()
    with torch.no_grad():
        # Preprocess the input text (same "classify:" prefix and padding
        # settings as used during training).
        inputs = tokenizer(
            [f"classify: {text}"],
            max_length=99,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Get model predictions
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(logits, dim=-1)

        # Map class index to label
        label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
        return label_map[preds.item()]


# Example usage
text = "NASA announces new mission to study asteroids"
result = infer(model, tokenizer, text)
print(f"Predicted category: {result}")
```