Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -72,8 +72,12 @@ def get_prediction(inputs):
|
|
72 |
outputs = model(**inputs)
|
73 |
logits = outputs.last_hidden_state[:, 0, :] # 取CLS标记的输出进行分类
|
74 |
pred_prob = torch.softmax(logits, dim=1)
|
75 |
-
pred = torch.argmax(pred_prob, dim=1)
|
76 |
-
|
|
|
|
|
|
|
|
|
77 |
|
78 |
# vectorizer= nltk_u.vectorizer()
|
79 |
# vectorizer.fit(train_data.text)
|
|
|
72 |
outputs = model(**inputs)
|
73 |
logits = outputs.last_hidden_state[:, 0, :] # 取CLS标记的输出进行分类
|
74 |
pred_prob = torch.softmax(logits, dim=1)
|
75 |
+
pred = torch.argmax(pred_prob, dim=1).item()
|
76 |
+
if pred in class_names:
|
77 |
+
return class_names[pred]
|
78 |
+
else:
|
79 |
+
print(f"Warning: Prediction index {pred} not found in class_names.")
|
80 |
+
return "Unknown" # 或者其他默认的响应
|
81 |
|
82 |
# vectorizer= nltk_u.vectorizer()
|
83 |
# vectorizer.fit(train_data.text)
|