added the id2labe
Browse files
app.py
CHANGED
@@ -54,6 +54,11 @@ def extract_predictions(outputs):
|
|
54 |
|
55 |
# a function that classifies text
|
56 |
|
|
|
|
|
|
|
|
|
|
|
57 |
def classify_text(text):
|
58 |
|
59 |
# Split text into segments using split_text
|
@@ -62,10 +67,7 @@ def classify_text(text):
|
|
62 |
# Initialize empty list for predictions
|
63 |
predictions = []
|
64 |
|
65 |
-
|
66 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
67 |
-
model = model.to(device)
|
68 |
-
|
69 |
# Loop through segments, process, and store predictions
|
70 |
for segment in segments:
|
71 |
inputs = tokenizer([segment], padding=True, return_tensors="pt")
|
@@ -77,13 +79,15 @@ def classify_text(text):
|
|
77 |
|
78 |
# Extract predictions for each segment
|
79 |
probs, preds = extract_predictions(outputs) # Define this function based on your model's output
|
80 |
-
|
81 |
# Append predictions for this segment
|
82 |
predictions.append({
|
83 |
"segment_text": segment,
|
84 |
-
"label":
|
85 |
-
"probability": probs[preds[0]] # Access probability for the predicted label
|
86 |
})
|
|
|
|
|
87 |
|
88 |
|
89 |
interface = gr.Interface(
|
|
|
54 |
|
55 |
# a function that classifies text
|
56 |
|
57 |
+
# Move device to GPU if available
|
58 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
59 |
+
model = model.to(device)
|
60 |
+
class_names=list(model.config.id2label.values())
|
61 |
+
|
62 |
def classify_text(text):
|
63 |
|
64 |
# Split text into segments using split_text
|
|
|
67 |
# Initialize empty list for predictions
|
68 |
predictions = []
|
69 |
|
70 |
+
|
|
|
|
|
|
|
71 |
# Loop through segments, process, and store predictions
|
72 |
for segment in segments:
|
73 |
inputs = tokenizer([segment], padding=True, return_tensors="pt")
|
|
|
79 |
|
80 |
# Extract predictions for each segment
|
81 |
probs, preds = extract_predictions(outputs) # Define this function based on your model's output
|
82 |
+
pred_label=class_names[preds[0].item()]
|
83 |
# Append predictions for this segment
|
84 |
predictions.append({
|
85 |
"segment_text": segment,
|
86 |
+
"label": pred_label, # Assuming single label prediction
|
87 |
+
"probability": probs[0][preds[0]].item() # Access probability for the predicted label
|
88 |
})
|
89 |
+
|
90 |
+
return predictions
|
91 |
|
92 |
|
93 |
interface = gr.Interface(
|