qgyd2021's picture
[update]add code
147e44c
raw
history blame
4.6 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.models.archival import archive_model, load_archive
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
import gradio as gr
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from project_settings import project_path
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader
def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: carries ``excel_file``, ``archive_file``,
        ``predictor_name``, ``top_k`` and ``output_file``.
    """
    # The default archive lives under the project's trained_models directory.
    default_archive = project_path / "trained_models/telemarketing_intent_classification_vi"

    # One (flag, default, type) entry per option, in registration order.
    # NOTE(review): the --excel_file default is an absolute path on one
    # developer machine; other environments must pass the flag explicitly.
    option_specs = [
        (
            "--excel_file",
            r"D:\Users\tianx\PycharmProjects\telemarketing_intent\data\excel\telemarketing_intent_vi.xlsx",
            str,
        ),
        ("--archive_file", default_archive.as_posix(), str),
        ("--predictor_name", "text_classifier", str),
        ("--top_k", 10, int),
        ("--output_file", "intent_top_k.jsonl", str),
    ]

    cli = argparse.ArgumentParser()
    for flag, default_value, value_type in option_specs:
        cli.add_argument(flag, default=default_value, type=value_type)

    return cli.parse_args()
def _nan_to_none(value):
    """Return ``None`` for a missing cell (``None``/NaN), else the value unchanged."""
    return None if pd.isna(value) else value


def _to_optional_int(value):
    """Convert a cell to ``int``; missing or unconvertible cells become ``None``."""
    if pd.isna(value):
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # Best effort: report the offending type and treat the cell as unset.
        print(type(value))
        return None


def main():
    """Predict the top-k labels for every row of the Excel sheet and dump them as JSON lines.

    Loads the archived model, reads ``args.excel_file`` and, for each row
    with a non-missing ``text`` cell, writes one JSON object to
    ``args.output_file``. Each object holds the row's ``source``, ``text``,
    ``label0``, ``label1``, ``selected`` and ``checked`` values plus the
    ``args.top_k`` most probable predicted labels and their probabilities,
    each joined with ``;`` and ordered from most to least probable.

    Returns:
        None. The output file is opened in ``"w"`` mode, so an existing file
        is overwritten.
    """
    args = get_args()

    archive = load_archive(archive_file=args.archive_file)
    predictor = Predictor.from_archive(archive, predictor_name=args.predictor_name)
    # Loop-invariant lookup, hoisted out of the per-row work.
    vocab = predictor._model.vocab

    df = pd.read_excel(args.excel_file)

    # NOTE(review): rows whose index is below this value are skipped. It looks
    # like a leftover resume offset from an interrupted run - confirm whether
    # it should be removed or exposed as an option. Kept to preserve behavior.
    first_row_index = 26976

    # Honor --top_k (previously the slice was hard-coded to 10, which is also
    # the flag's default); a negative value is treated as "no labels".
    top_k = max(args.top_k, 0)

    with open(args.output_file, "w", encoding="utf-8") as f:
        for i, row in tqdm(df.iterrows(), total=len(df)):
            if i < first_row_index:
                continue

            text = row["text"]
            if pd.isna(text):
                # Nothing to classify for this row.
                continue
            text = str(text)

            source = _nan_to_none(row["source"])
            label0 = _nan_to_none(row["label0"])
            label1 = _nan_to_none(row["label1"])
            selected = _to_optional_int(row["selected"])
            checked = _to_optional_int(row["checked"])

            outputs = predictor.predict_json({"sentence": text})
            probs = outputs["probs"]

            # argsort is ascending; reverse it so the most probable label
            # comes first, then keep the first top_k indices.
            top_indices = np.argsort(probs)[::-1][:top_k]

            # Only the part after the last "_" of each label token is kept.
            label_top_k = [
                vocab.get_token_from_index(index=int(idx), namespace="labels").split("_")[-1]
                for idx in top_indices
            ]
            prob_top_k = [str(round(probs[int(idx)], 5)) for idx in top_indices]

            record = {
                "source": source,
                "text": text,
                "label0": label0,
                "label1": label1,
                "selected": selected,
                "checked": checked,
                "predict_label_top_k": ";".join(label_top_k),
                "predict_prob_top_k": ";".join(prob_top_k),
            }
            f.write("{}\n".format(json.dumps(record, ensure_ascii=False)))
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()