|
---
license: apache-2.0
datasets:
- cnmoro/QuestionClassification
tags:
- classification
- questioning
- directed
- generic
language:
- en
- pt
library_name: transformers
pipeline_tag: text-classification
widget:
- text: "What is the summary of the text?"
---
|
|
|
A fine-tuned version of prajjwal1/bert-tiny.
|
|
|
The goal is to classify questions as "Directed" or "Generic".
|
|
|
If a question is not directed, we can change the actions performed in a RAG pipeline: for a generic question (e.g. a request for a summary), semantic search wouldn't be directly useful. A routing sketch is included at the end of this card.
|
|
|
(Class 0 is Generic; Class 1 is Directed) |
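
For a quick check of the two classes, the `text-classification` pipeline can also be used. This is a minimal sketch: the input is lowercased to match the preprocessing in the example further down, and depending on the model config the raw labels may surface as `LABEL_0`/`LABEL_1`, following the mapping above.

```python
from transformers import pipeline

# Load the classifier; class 0 = Generic, class 1 = Directed
# (labels may appear as "LABEL_0"/"LABEL_1" depending on the config).
classifier = pipeline("text-classification", model="cnmoro/bert-tiny-question-classifier")

# Lowercasing mirrors the preprocessing used in the example below.
print(classifier("What is the summary of the text?".lower()))
print(classifier("Who wrote the third paragraph of the document?".lower()))
```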
|
|
|
Accuracy on the training dataset is around 87.5%.
|
|
|
```python
from transformers import BertForSequenceClassification, BertTokenizerFast
import torch

# Load the model and tokenizer
model = BertForSequenceClassification.from_pretrained("cnmoro/bert-tiny-question-classifier")
tokenizer = BertTokenizerFast.from_pretrained("cnmoro/bert-tiny-question-classifier")

def is_question_generic(question):
    # Tokenize the lowercased question and convert to PyTorch tensors
    inputs = tokenizer(
        question.lower(),
        truncation=True,
        padding=True,
        return_tensors="pt",
        max_length=512
    )

    # Run the model without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)

    # Pick the highest-scoring class (0 = Generic, 1 = Directed)
    predicted_class = torch.argmax(outputs.logits, dim=-1).item()

    return predicted_class == 0
```
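
For example, the helper can drive routing in a RAG pipeline, as described above. The sketch below is illustrative only: `retrieve_chunks` and `generate_answer` are hypothetical stand-ins for your own retrieval and generation steps.

```python
def retrieve_chunks(question, documents):
    # Hypothetical stand-in for a semantic-search / vector-retrieval step.
    return documents[:3]

def generate_answer(question, context):
    # Hypothetical stand-in for the final LLM generation step.
    return f"Answer to {question!r} using {len(context)} chunk(s)"

def answer(question, documents):
    if is_question_generic(question):
        # Generic question (e.g. a summary request): semantic search
        # wouldn't help directly, so keep the whole document set.
        context = documents
    else:
        # Directed question: retrieve the most relevant chunks first.
        context = retrieve_chunks(question, documents)
    return generate_answer(question, context)

print(answer("What is the summary of the text?", ["chunk A", "chunk B"]))
```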