|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import gradio as gr |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
import torch |
|
import transformers |
|
|
|
# Hugging Face Hub repository holding both the tokenizer and the fine-tuned
# sequence-classification weights (the repo name suggests a RoBERTa checkpoint).
_MODEL_ID = "nebiyu29/fintunned-v2-roberta_GA"

# Both objects are downloaded/loaded once at import time and shared by the
# helper functions below.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_text(text, max_tokens=None):
    """Split *text* into consecutive segments small enough for one model pass.

    The text is tokenized with the module-level ``tokenizer`` and cut into
    back-to-back windows of at most ``max_tokens`` tokens. Each window is
    converted back into a string.

    Args:
        text: The raw text to split.
        max_tokens: Maximum number of tokens per segment, NOT counting the
            special tokens the tokenizer adds when a segment is encoded.
            Defaults to ``min(tokenizer.model_max_length, 512)`` minus that
            special-token count, so an encoded segment fits in a 512-token
            model input.

    Returns:
        A list of segment strings in original order. The list is empty when
        *text* produces no tokens.

    Raises:
        ValueError: If ``max_tokens`` is less than 1.
    """
    print("going to split the text")

    if max_tokens is None:
        # Reserve room for the special tokens (e.g. <s> ... </s>) added at
        # encode time; otherwise a full segment overflows the model input.
        max_tokens = (
            min(tokenizer.model_max_length, 512)
            - tokenizer.num_special_tokens_to_add(pair=False)
        )
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")

    tokens = tokenizer.tokenize(text)

    # Slice by position so segment boundaries depend only on token count, never
    # on whether a token happens to equal some other token's value.
    return [
        tokenizer.convert_tokens_to_string(tokens[start:start + max_tokens])
        for start in range(0, len(tokens), max_tokens)
    ]
|
|
|
|
|
def extract_predictions(outputs):
    """Turn raw model outputs into class probabilities and predicted class ids.

    Args:
        outputs: Any object exposing a ``logits`` tensor. The softmax and
            argmax are taken over dimension 1, so this presumably expects
            logits shaped (batch, num_labels). TODO confirm for other models.

    Returns:
        A ``(probs, preds)`` tuple. ``probs`` is the softmax of the logits over
        dimension 1. ``preds`` is the index of the largest probability along
        that same dimension, one entry per row.
    """
    class_probabilities = torch.softmax(outputs.logits, dim=1)
    predicted_ids = class_probabilities.argmax(dim=1)
    return class_probabilities, predicted_ids
|
|
|
|
|
|
|
|
|
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)
# Inference only: make sure dropout and similar training-time layers are off.
model.eval()

# Label names indexed by class id. Built from the sorted id2label keys rather
# than dict iteration order, so class_names[i] is the label for logit column i.
class_names = [model.config.id2label[i] for i in sorted(model.config.id2label)]
|
|
|
def classify_text(text):
    """Classify *text* with the fine-tuned model, one prediction per segment.

    The text is split with ``split_text`` so long inputs are covered in full.
    Each segment is encoded, run through ``model`` on ``device`` without
    gradient tracking, and reduced to its most probable label.

    Args:
        text: Raw text to classify.

    Returns:
        A list of dicts, one per segment and in order. Each dict has these keys:

        - ``"segment_text"``: the segment string.
        - ``"label"``: the predicted label name from ``model.config.id2label``.
        - ``"probability"``: the softmax probability of that label, as a float.

        An empty list is returned for ``None`` or ``""``.
    """
    if not text:
        return []

    # A decoded segment is not guaranteed to re-tokenize to the same length,
    # so truncate at encode time rather than overflow the model's positions.
    max_length = min(tokenizer.model_max_length, 512)

    predictions = []
    for segment in split_text(text):
        inputs = tokenizer(
            segment, truncation=True, max_length=max_length, return_tensors="pt"
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)

        probs, preds = extract_predictions(outputs)
        label_id = preds[0].item()
        predictions.append({
            "segment_text": segment,
            # Look the name up by id instead of relying on label-list ordering.
            "label": model.config.id2label[label_id],
            "probability": probs[0, label_id].item(),
        })

    return predictions
|
|
|
|
|
# Gradio UI: a text box in, per-segment predictions out. classify_text returns
# a list of dicts, so a JSON output shows each segment's label and probability
# as structured data instead of the Python repr a plain text box would display.
interface = gr.Interface(
    fn=classify_text,
    inputs="text",
    outputs="json",
    title="Text Classification Demo",
    description="Enter some text, and the model will classify it.",
)
# NOTE(review): no interface.launch() call is visible in this file. Confirm
# the app is launched elsewhere (e.g. `if __name__ == "__main__":
# interface.launch()`).
|
|
|
|
|
|