import gradio as gr
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import numpy as np
import faiss
import ast
import torch.nn.functional as F
import torch
from transformers import AutoModel, AutoTokenizer

# Embedding model used to build and query the FAISS index
Encoding_model = 'jinaai/jina-embeddings-v2-base-zh'
model = AutoModel.from_pretrained(Encoding_model, trust_remote_code=True, torch_dtype=torch.bfloat16)
model  # .to("cuda")

# Similarity model used to re-rank the FAISS candidates
similarity_model = 'Alibaba-NLP/gte-multilingual-base'
similarity_tokenizer = AutoTokenizer.from_pretrained(similarity_model)
similarity_model = AutoModel.from_pretrained(similarity_model, trust_remote_code=True)  # .to("cuda")


def get_not_empty_data(df, x_column="text", y_column="label"):
    # Keep only rows with a non-empty label list, merge duplicate texts and
    # deduplicate their labels.
    df = df[df[y_column] != "[]"].reset_index(drop=True)
    res_dict = {}
    for idx in df.index:
        if df.loc[idx, x_column] not in res_dict:
            res_dict[df.loc[idx, x_column]] = ast.literal_eval(df.loc[idx, y_column])
        else:
            res_dict[df.loc[idx, x_column]] += ast.literal_eval(df.loc[idx, y_column])
    res_dict = {k: list(set(v)) for k, v in res_dict.items()}
    df_dict = pd.DataFrame({"x": res_dict.keys(), "y": res_dict.values()})
    return df_dict


data_all = pd.read_excel("data_Excel_format.xlsx")
df_dict_all = get_not_empty_data(data_all)
x_dict = df_dict_all["x"].values
y_dict = df_dict_all["y"].values


def calc_scores(x):
    # Cosine similarity of the first (query) embedding against the remaining
    # (candidate) embeddings; inputs are already L2-normalized.
    return x[:1] @ x[1:].T


def get_idxs(threshold, max_len, arr):
    # Indices of entries in arr that pass the threshold, capped at max_len
    # (the tail of this helper was lost; assumed to keep the highest scores).
    res = np.where(arr >= threshold)[0]
    if len(res) < max_len:
        return res
    return res[np.argsort(-arr[res])][:max_len]


def merge_set_to_list(set_list):
    # Union a list of label sets into one deduplicated list
    # (definition reconstructed from its call site below).
    merged = set()
    for s in set_list:
        merged |= s
    return list(merged)


def get_predict_result(index, scores, threshold, max_len):
    # Keep FAISS candidates whose re-ranking score passes the threshold, take at
    # most max_len of them (highest score first) and merge their label sets
    # (signature and opening lines reconstructed from the call in predict_label).
    scores = np.asarray(scores).flatten()
    index_of_index = np.where(scores >= threshold)[0]
    if len(index_of_index) >= max_len:
        index_of_index = index_of_index[np.argsort(-scores[index_of_index])][:max_len]
    if len(index_of_index) == 0:
        return {}, []
    res_index = index[index_of_index]
    res = merge_set_to_list([set(i) for i in y_dict[res_index]])
    return res, x_dict[res_index]


# One-off build of the FAISS index over all reference sentences (kept for reference):
# vec = np.empty(shape=[0, 768], dtype="float32")
# bsize = 256
# with torch.no_grad():
#     for i in range(0, len(x_dict), bsize):
#         tmp = model.encode(x_dict[i:i+bsize])
#         vec = np.concatenate([vec, tmp])
# index = faiss.IndexFlatIP(768)
# faiss.normalize_L2(vec)
# index.add(vec)
# faiss.write_index(index, "all_index.faiss")

index = faiss.read_index("all_index.faiss")


def predict_label(x, threshold=0.85, n_nearest=10, max_result_len=3):
    bsize = 1
    y_pred = []
    with torch.no_grad():
        for i in range(0, len(x), bsize):
            sentences = x[i:i+bsize]
            # Coarse retrieval: n_nearest candidates from the FAISS index
            vec = model.encode(sentences)
            faiss.normalize_L2(vec)
            scores, indexes = index.search(vec, n_nearest)
            # Fine re-ranking: encode query + candidates with the similarity model
            x_pred = np.array([[sentences[j]] + s.tolist() for j, s in enumerate(x_dict[indexes])])
            batch_dict = similarity_tokenizer(x_pred.flatten().tolist(), max_length=768,
                                              padding=True, truncation=True, return_tensors='pt')  # .to("cuda")
            outputs = similarity_model(**batch_dict)
            dimension = 768
            embeddings = outputs.last_hidden_state[:, 0][:dimension]
            embeddings = F.normalize(embeddings, p=2, dim=1)
            embeddings = embeddings.view(len(x_pred), n_nearest + 1, dimension).detach().cpu().numpy()
            scores = [calc_scores(embeddings[b]) for b in range(embeddings.shape[0])]
            pred = [get_predict_result(indexes[k], scores[k], threshold=threshold, max_len=max_result_len)
                    for k in range(len(scores))]
            y_pred.append([p[0] for p in pred])
    return y_pred


CSS_Content = """
红色字体:潜在风险<br>
蓝色字体:权限获取<br>
紫色字体:数据收集<br>
绿色字体:数据、权限管理<br>
棕色字体:共享、委托、转让、公开(披露)<br>
"""

color_dict = {"潜在风险": "red",
              "权限获取": "blue",
              "数据收集": "purple",
              "数据、权限管理": "green",
              "共享、委托、转让、公开(披露)": "brown"}


def generate_HTML(text, threshold=0.85, n_nearest=10, max_result_len=3):
    # Split the input into paragraphs (newline) and sentences ("。"), predict the
    # labels of every sentence and wrap it in a <span> colored by its top label.
    sentences = text.split("\n")
    sentences = [line.split("。") for line in sentences]
    res = CSS_Content
    for paragraph in sentences:
        tmp_res = []
        pred_label = predict_label(paragraph, threshold, n_nearest, max_result_len)
        for i, x in enumerate(pred_label):
            pre = "<span"
            if len(x[0]) > 0:
                # color_dict keys are ordered by decreasing importance,
                # so only the first matching label's color is used
                for j in color_dict.keys():
                    if j in x[0]:
                        pre += f' style="color: {color_dict[j]};line-height:1;"'
                        break
            tmp_res.append(pre + ">" + paragraph[i] + "</span>")
        res += "。".join(tmp_res)
        res += "<br>"
    return res


with gr.Blocks() as demo:
    with gr.Row():
        input_text = gr.Textbox(lines=25, label="输入")
    with gr.Row():
        threshold = gr.Slider(minimum=0.5, maximum=0.85, value=0.75, step=0.05,
                              interactive=True, label="相似度阈值")
        n_nearest = gr.Slider(minimum=3, maximum=10, value=10, step=1,
                              interactive=True, label="粗筛语句数量")
        max_result_len = gr.Slider(minimum=1, maximum=5, value=3, step=1,
                                   interactive=True, label="精筛语句数量")
    with gr.Row():
        submit_button = gr.Button("检测")
    with gr.Row():
        output_text = gr.HTML(CSS_Content, elem_id="custom_id")
    submit_button.click(fn=generate_HTML,
                        inputs=[input_text, threshold, n_nearest, max_result_len],
                        outputs=output_text)

demo.launch()