File size: 1,879 Bytes
984b548 ff51298 984b548 ff51298 984b548 da4c49d 984b548 1920585 cdaa274 1920585 66c789c 1920585 66c789c 1920585 66c789c 1920585 66c789c 1920585 66c789c 1920585 71b2877 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
---
language:
- tw
tags:
- albert
- classification
license: afl-3.0
metrics:
- Accuracy
---
# Traditional Chinese news classification
繁體中文新聞分類任務,使用ckiplab/albert-base-chinese預訓練模型,資料集只有2.6萬筆,做為課程的範例模型。
from transformers import BertTokenizer, AlbertForSequenceClassification
model_path = "clhuang/albert-news-classification"
model = AlbertForSequenceClassification.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
# Category index
news_categories=['政治','科技','運動','證卷','產經','娛樂','生活','國際','社會','文化','兩岸']
idx2cate = { i : item for i, item in enumerate(news_categories)}
# get category probability
def get_category_proba( text ):
max_length = 250
# prepare token sequence
inputs = tokenizer([text], padding=True, truncation=True, max_length=max_length, return_tensors="pt")
# perform inference
outputs = model(**inputs)
# get output probabilities by doing softmax
probs = outputs[0].softmax(1)
# executing argmax function to get the candidate label index
label_index = probs.argmax(dim=1)[0].tolist() # convert tensor to int
# get the label name
label = idx2cate[ label_index ]
# get the label probability
proba = round(float(probs.tolist()[0][label_index]),2)
response = {'label': label, 'proba': proba}
return response
get_category_proba('俄羅斯2月24日入侵烏克蘭至今不到3個月,芬蘭已準備好扭轉奉行了75年的軍事不結盟政策,申請加入北約。芬蘭總理馬林昨天表示,「希望我們下星期能與瑞典一起提出申請」。')
{'label': '國際', 'proba': 0.99} |