#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo for telemarketing intent classification.

One AllenNLP predictor is loaded per supported language (Chinese, English,
Japanese, Vietnamese) and exposed through a single ``gr.Interface`` with a
text box and a language dropdown.
"""
import argparse
import os
import time

# NOTE(review): several of the imports below are not referenced by name in
# this script. They are presumably kept for their import-time side effects
# (component registration needed when un-archiving the models) - confirm
# before removing any of them.
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
from allennlp.data.token_indexers.single_id_token_indexer import SingleIdTokenIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.models.archival import archive_model, load_archive
from allennlp_models.rc.modules.seq2seq_encoders.stacked_self_attention import StackedSelfAttentionEncoder
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
import gradio as gr
import torch

from project_settings import project_path
from toolbox.allennlp_models.text_classifier.models.hierarchical_text_classifier import HierarchicalClassifier
from toolbox.allennlp_models.text_classifier.dataset_readers.hierarchical_classification_json import HierarchicalClassificationJsonReader


# Language name shown in the UI -> name of the CLI attribute that holds the
# path of that language's model archive. Insertion order is the load order.
_LANGUAGE_TO_ARCHIVE_ARG = {
    "chinese": "cn_archive_file",
    "english": "en_archive_file",
    "japanese": "jp_archive_file",
    "vietnamese": "vi_archive_file",
}

# Predictor used when the dropdown value is missing or not a known language.
_FALLBACK_LANGUAGE = "chinese"

# User-facing text shown under the demo title. Roughly: "Telemarketing intent
# recognition. Languages: Chinese, English, Japanese, Vietnamese. The dataset
# is private.", followed by model / dataset / accuracy figures.
# The whitespace inside the text is reproduced exactly as found in the source.
_DESCRIPTION = (
    " 电销场景意图识别. 语言: 汉语, 英语, 日语, 越南语. 数据集是私有的. \n"
    "model: selfattention-cnn "
    "dataset: telemarketing_intent "
    "(https://huggingface.co/datasets/qgyd2021/telemarketing_intent) "
    "accuracy: chinese: 0.8002 english: 0.7011 japanese: 0.8154 vietnamese: 0.8168 "
)

# Clickable (text, language) samples offered by the interface.
_EXAMPLES = [
    ["你找谁", "chinese"],
    ["你是谁啊", "chinese"],
    ["不好意思我现在很忙", "chinese"],
    ["对不起, 不需要哈", "chinese"],
    ["u have got the wrong number", "english"],
    ["sure, thank a lot", "english"],
    ["please leave your message for 95688496", "english"],
    ["yes well", "english"],
    ["失礼の", "japanese"],
    ["ビートいう発表の後に、お名前とご用件をお話ください。", "japanese"],
    ["わかんない。", "japanese"],
    ["に出ることができません", "japanese"],
    ["À không phải em nha.", "vietnamese"],
    ["Dạ nhầm số rồi ạ?", "vietnamese"],
    ["Ừ, cảm ơn em nhá.", "vietnamese"],
    ["Không, chị không có tiền.", "vietnamese"],
]


def get_args() -> argparse.Namespace:
    """Parse the command-line options of the demo.

    Returns:
        Namespace with one ``*_archive_file`` path per language (defaulting to
        directories under ``project_path / "trained_models"``) and
        ``predictor_name`` (default ``"text_classifier"``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cn_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_cn").as_posix(),
        type=str
    )
    parser.add_argument(
        "--en_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_en").as_posix(),
        type=str
    )
    parser.add_argument(
        "--jp_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_jp").as_posix(),
        type=str
    )
    parser.add_argument(
        "--vi_archive_file",
        default=(project_path / "trained_models/telemarketing_intent_classification_vi").as_posix(),
        type=str
    )
    parser.add_argument(
        "--predictor_name",
        default="text_classifier",
        type=str
    )
    args = parser.parse_args()
    return args


def _load_predictors(args: argparse.Namespace) -> dict:
    """Build one predictor per language from the archives named in *args*.

    Archives are loaded in the order of ``_LANGUAGE_TO_ARCHIVE_ARG``; every
    predictor is created with the same ``args.predictor_name``.
    """
    return {
        language: Predictor.from_archive(
            load_archive(archive_file=getattr(args, arg_name)),
            predictor_name=args.predictor_name,
        )
        for language, arg_name in _LANGUAGE_TO_ARCHIVE_ARG.items()
    }


def main() -> None:
    """Load all predictors, build the Gradio interface and launch it."""
    args = get_args()

    predictor_map = _load_predictors(args)
    fallback_predictor = predictor_map[_FALLBACK_LANGUAGE]

    def fn(text: str, language: str):
        """Classify *text* with the predictor selected by *language*.

        Returns:
            ``(label, prob)`` where ``prob`` is rounded to 4 decimals.
        """
        # The dropdown has no preset value, so ``language`` may not be a key
        # of the map; fall back to the Chinese predictor in that case.
        predictor = predictor_map.get(language, fallback_predictor)

        outputs = predictor.predict_json({'sentence': text})
        # NOTE(review): relies on the private ``_model`` attribute and on the
        # model exposing ``decode``; the result is indexed with ``[0]``, so it
        # presumably holds list-like ``label`` / ``prob`` entries - confirm
        # against the classifier implementation.
        outputs = predictor._model.decode(outputs)

        label = outputs['label'][0]
        prob = round(outputs['prob'][0], 4)
        return label, prob

    demo = gr.Interface(
        fn=fn,
        inputs=[
            gr.Text(label="text"),
            gr.Dropdown(
                # sorted() already returns a list of the dict's keys.
                choices=sorted(predictor_map),
                label="language"
            )
        ],
        outputs=[gr.Text(label="intent"), gr.Number(label="prob")],
        examples=_EXAMPLES,
        examples_per_page=50,
        title="Telemarketing Intent Classification",
        description=_DESCRIPTION,
    )
    demo.launch()


if __name__ == '__main__':
    main()