File size: 3,621 Bytes
7a5da00
 
 
 
 
7cab2f9
7a5da00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
from scipy.spatial.distance import cosine
from transformers import AutoModel, AutoTokenizer
from argparse import Namespace
import torch
from tsne import TSNE_Plot

# Extra settings handed to the checkpoint's model class through `model_args`.
model_args = Namespace(
    do_mlm=None,
    pooler_type="cls",
    temp=0.05,
    mlp_only_train=False,
    init_embeddings_model=None,
)

# Tokenizer and encoder are loaded from the same Hugging Face checkpoint.
tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")

# NOTE(review): trust_remote_code=True lets the checkpoint repository supply
# the model implementation; keep it only while that repository is trusted.
model = AutoModel.from_pretrained(
    "silk-road/luotuo-bert",
    model_args=model_args,
    trust_remote_code=True,
)

def divide_str(s, sep=('\n', '.', '。')):
    """Split ``s`` in two just after the separator nearest its midpoint.

    Only the part of ``s`` before the midpoint (``len(s) // 2``) is searched,
    so the chosen separator is the right-most one that lies entirely in the
    first half. A separator starting at index 0 is not accepted.

    Args:
        s: The string to split.
        sep: Iterable of candidate separator strings. Separators may be longer
            than one character; empty strings are ignored.

    Returns:
        A ``(left, right)`` tuple with ``left + right == s``. ``left`` ends
        with the chosen separator. When no separator qualifies, ``(s, '')``
        is returned.
    """
    mid_len = len(s) // 2
    # Every hit is left of the midpoint, so "closest to the midpoint" is simply
    # the largest index. Starting at 0 also rejects a separator at index 0.
    best_pos = 0
    best_sep = None
    for curr_sep in sep:
        if not curr_sep:
            # rfind('') always matches and would shadow every real separator.
            continue
        pos = s.rfind(curr_sep, 0, mid_len)
        if pos > best_pos:
            best_pos, best_sep = pos, curr_sep
    if best_sep is None:
        return s, ''
    # Cut after the whole separator, not just after its first character.
    cut = best_pos + len(best_sep)
    return s[:cut], s[cut:]

def strong_divide(s):
    """Split ``s`` into a ``(left, right)`` pair, falling back step by step.

    The attempts, in order:
      1. ``divide_str`` with its default separators.
      2. ``divide_str`` with a broad punctuation set (ASCII and full-width
         Chinese forms).
      3. A plain cut at ``len(s) // 2``.

    Args:
        s: The string to split.

    Returns:
        A ``(left, right)`` tuple with ``left + right == s``. ``right`` is
        empty only when ``s`` itself is empty.
    """
    left, right = divide_str(s)
    if right != '':
        return left, right

    # Weaker split points, tried only when no sentence-level separator was
    # found in the first half of the string.
    fallback_sep = (
        '\n', '.', ',', '、', ';', ':', '!', '?', '(', ')',
        # Full-width counterparts, common in Chinese text.
        '，', '；', '：', '！', '？', '（', '）',
        '”', '“', '’', '‘', '[', ']', '{', '}', '<', '>',
        '/', "'", '|', '-', '=', '+', '*', '%',
        '$', '#', '@', '&', '^', '_', '`', '~',
        '·', '…',
    )
    left, right = divide_str(s, sep=fallback_sep)
    if right != '':
        return left, right

    # No punctuation at all in the first half: cut at the midpoint.
    mid_len = len(s) // 2
    return s[:mid_len], s[mid_len:]

def _embed_sentences(sentences):
    """Tokenize ``sentences`` as one padded batch and return the model's ``pooler_output``."""
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        return model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output


def generate_image(text_input):
    """Build a t-SNE figure from the two halves of every input line.

    Each non-blank line is one entry. Text after the first ``#`` on a line is
    used as that entry's label; entries without a label (or with an empty one)
    are labelled ``No.<line index>``. Every entry is split in two with
    ``strong_divide``, both halves are embedded, and all halves are passed to
    ``TSNE_Plot``.

    Args:
        text_input: Multi-line string, one entry per line, optionally
            ``text#label``.

    Returns:
        The figure returned by ``TSNE_Plot.tsne_plot``, or ``None`` when the
        input contains no non-blank entry.
    """
    texts = []
    labels = []
    for idx, line in enumerate((text_input or '').split('\n')):
        text, _, name = line.partition('#')
        text = text.strip()
        if not text:
            # Blank lines (e.g. a trailing newline) carry nothing to embed.
            continue
        texts.append(text)
        labels.append(name.strip() or 'No.{}'.format(idx))

    if not texts:
        # Nothing to plot; avoid sending an empty batch to the tokenizer.
        return None

    halves = [strong_divide(text) for text in texts]
    text_left = [left for left, _ in halves]
    text_right = [right for _, right in halves]

    embeddings_left = _embed_sentences(text_left)
    embeddings_right = _embed_sentences(text_right)

    # All left halves come first, then all right halves, so the label list is
    # repeated once to line up with the merged order.
    merged_list = text_left + text_right
    merged_embed = torch.cat((embeddings_left, embeddings_right), dim=0)
    tsne_plot = TSNE_Plot(merged_list, merged_embed, label=labels * 2, n_annotation_positions=len(merged_list))
    return tsne_plot.tsne_plot(n_sentence=len(merged_list), return_fig=True)

# UI: a multi-line text box feeding generate_image, whose figure is shown in a plot.
with gr.Blocks() as demo:
    # gr.Textbox replaces the deprecated gr.inputs.Textbox legacy component.
    name = gr.Textbox(lines=20,
          placeholder='在此输入歌词,每一行为一个输入,如果需要输入歌词对应的歌名,请用#隔开\n例如:听雷声 滚滚 他默默 闭紧嘴唇 停止吟唱暮色与想念 他此刻沉痛而危险 听雷声 滚滚 他渐渐 感到胸闷 乌云阻拦明月涌河湾 他起身独立向荒原#河北墨麒麟')
    output = gr.Plot()
    btn = gr.Button("Generate")
    btn.click(fn=generate_image, inputs=name, outputs=output, api_name="generate-image")

# Start the server only when run as a script, so importing this module
# (e.g. to reuse `demo` or the helper functions) does not launch it.
if __name__ == "__main__":
    demo.launch(debug=True)