|
--- |
|
language: |
|
- zh |
|
tags: |
|
- albert |
|
- classification |
|
license: afl-3.0 |
|
metrics: |
|
- Accuracy |
|
--- |
|
# Traditional Chinese news classification |
|
|
|
繁體中文新聞分類任務，使用 ckiplab/albert-base-chinese 預訓練模型，資料集只有 2.6 萬筆，做為課程的範例模型。 |
|
|
|
from transformers import BertTokenizer, AlbertForSequenceClassification |
|
# Identifier of the news-classification model to load.
model_path = "clhuang/albert-news-classification"

# Sequence-classification model loaded from `model_path`.
model = AlbertForSequenceClassification.from_pretrained(
    model_path
)

# The tokenizer is loaded from "bert-base-chinese", not from `model_path`.
tokenizer = BertTokenizer.from_pretrained(
    "bert-base-chinese"
)
|
|
|
# get category probability |
|
def get_category_proba(text: str, max_length: int = 250) -> dict:
    """Classify a news text and return its most likely category.

    Relies on the module-level ``tokenizer`` and ``model`` objects and on an
    ``idx2cate`` lookup (label index -> category name). NOTE(review):
    ``idx2cate`` is not defined in the visible part of this file and must be
    provided before this function is called.

    Args:
        text: The news text to classify.
        max_length: Maximum number of tokens kept by the tokenizer; longer
            inputs are truncated. Defaults to 250.

    Returns:
        A dict ``{'label': <category name>, 'proba': <probability>}`` where
        ``proba`` is the softmax probability of the predicted label, rounded
        to two decimal places.
    """
    # Local import: the file does not import torch at module level.
    import torch

    # Prepare the token sequence as a batch containing the single text.
    inputs = tokenizer(
        [text],
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    # Inference only: no gradients are needed, so do not build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # Turn the first model output into per-class probabilities (softmax over dim 1).
    probs = outputs[0].softmax(1)

    # Index of the most probable class for the only item in the batch.
    label_index = probs.argmax(dim=1)[0].item()

    # Map the index to its category name.
    label = idx2cate[label_index]

    # Probability of the predicted class, rounded for display.
    proba = round(probs[0, label_index].item(), 2)

    return {'label': label, 'proba': proba}
|
|
|
# Example usage: classify a single news paragraph.
get_category_proba(
    '俄羅斯2月24日入侵烏克蘭至今不到3個月,芬蘭已準備好扭轉奉行了75年的軍事不結盟政策,申請加入北約。芬蘭總理馬林昨天表示,「希望我們下星期能與瑞典一起提出申請」。'
)

# Result shown for the call above in the original example:
{'label': '國際', 'proba': 0.99}