lllchenlll committed on
Commit
5cd70f3
·
1 Parent(s): c3a5286

Upload 5 files

Browse files
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import openai
4
+
5
+ from sentence_transformers import SentenceTransformer
6
+ from langchain.prompts import PromptTemplate
7
+ from collections import Counter
8
+
9
+
10
# Cosine-similarity cut-off above which a library tag counts as a match.
_SIMILARITY_THRESHOLD = 0.8

# Lazily filled cache for the tag library and the sentence encoder. Loading
# them is expensive (large .npy matrix + transformer weights), so it is done
# on first use instead of on every request.
_TAG_RESOURCES = {}


def _load_tag_resources():
    """Return ``(tags_embed, tags_dis, all_tags, encoder)``, loading them once."""
    cached = _TAG_RESOURCES.get("value")
    if cached is None:
        tags_embed = np.load('./tag_data/tags_embed.npy')
        # Flattened so that it broadcasts against the score matrix columns.
        tags_dis = np.ravel(np.load('./tag_data/tags_dis.npy'))
        # Blank lines are kept on purpose: line i must stay aligned with
        # row i of the embedding matrix.
        with open('./tag_data/tags.txt', 'r', encoding='utf-8') as f:
            all_tags = [line.strip() for line in f]
        encoder = SentenceTransformer("hfl/chinese-roberta-wwm-ext-large")
        cached = (tags_embed, tags_dis, all_tags, encoder)
        # Single assignment so a concurrent caller never sees a partial cache.
        _TAG_RESOURCES["value"] = cached
    return cached


def _split_answer(answer):
    """Split one model answer into cleaned, non-empty tag strings."""
    text = answer.strip().replace("\n", "").replace("。", "")
    # Accept both the ASCII comma and the full-width comma (U+FF0C) as
    # alternative separators for the expected "、".
    text = text.replace(",", "、").replace("\uff0c", "、")
    pieces = (piece.strip() for piece in text.split('、'))
    return [piece for piece in pieces if piece]


def _match_library_tags(candidate_embed, tags_embed, tags_dis, all_tags):
    """Return every library tag whose cosine similarity to a candidate exceeds the threshold.

    The result is ordered candidate by candidate, then by library index, and a
    library tag appears once per candidate it matches.
    """
    candidate_embed = np.asarray(candidate_embed)
    candidate_dis = np.sqrt(np.sum(candidate_embed * candidate_embed, axis=1))
    scores = np.dot(candidate_embed, tags_embed.T)
    # A zero-length vector yields nan/inf, which never passes the threshold.
    with np.errstate(divide='ignore', invalid='ignore'):
        cosine = scores / (candidate_dis[:, None] * tags_dis[None, :])
    # np.nonzero walks the matrix in row-major order, i.e. the same order as a
    # nested "for candidate: for library tag" loop.
    _, columns = np.nonzero(cosine > _SIMILARITY_THRESHOLD)
    return [all_tags[j] for j in columns]


def process(api, caption, category, asr, ocr):
    """Generate interest tags for a video and map them onto the tag library.

    The video fields are inserted into a one-shot prompt, five completions are
    requested from ``gpt-3.5-turbo``, the ten most frequent generated tags are
    embedded, and library tags with cosine similarity above 0.8 are collected.

    Args:
        api: OpenAI API key used for the request.
        caption: Video title inserted into the prompt.
        category: Video category inserted into the prompt.
        asr: Speech-recognition text inserted into the prompt.
        ocr: On-screen text inserted into the prompt.

    Returns:
        Up to five library tags joined with "、" (most frequently matched
        first), an empty string when nothing usable was generated or matched,
        or the string ``'api error'`` when any step after prompt construction
        fails.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Kept for backward compatibility; the key is also passed per request
    # below so concurrent callers cannot pick up each other's key.
    openai.api_key = api
    preference = "兴趣标签"
    example = (
        "例如,给定一个视频,它的\"标题\"为\"长安系最便宜的轿车,4W起很多人都看不上它,但我知道车只是代步工具,又需要什么面子呢"
        "!\",\"类别\"为\"汽车\",\"ocr\"为\"长安系最便宜的一款轿车\",\"asr\"为\"我不否认现在的国产和合资还有一定的差距,"
        "但确实是他们让我们5万开了MP V8万开上了轿车,10万开张了ICV15万开张了大七座。\",\"{}\"生成机器人推断出合理的\"{}\"为\""
        "长安轿车报价、最便宜的长安轿车、新款长安轿车\"。"
    ).format(preference, preference)

    prompt = PromptTemplate(
        input_variables=["preference", "caption", "ocr", "asr", "category", "example"],
        template="你是一个视频的\"{preference}\"生成机器人,根据输入的视频标题、类别、ocr、asr推理出合理的\"{preference}\",以多个多"
                 "于两字的标签形式进行表达,以顿号隔开。{example}那么,给定一个新的视频,它的\"标题\"为\"{caption}\",\"类别\"为"
                 "\"{category}\",\"ocr\"为\"{ocr}\",\"asr\"为\"{asr}\",请推断出该视频的\"{preference}\":"
    )

    text = prompt.format(preference=preference, caption=caption, category=category, ocr=ocr, asr=asr, example=example)

    try:
        completion = openai.ChatCompletion.create(
            api_key=api,
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": text}],
            temperature=1.5,
            n=5
        )

        # Iterate over whatever came back instead of assuming exactly five.
        generated = []
        for choice in completion.choices:
            generated.extend(_split_answer(choice.message["content"]))

        candidate_tags = [tag for tag, _ in Counter(generated).most_common(10)]
        if not candidate_tags:
            return ""

        tags_embed, tags_dis, all_tags, encoder = _load_tag_resources()
        candidate_embed = encoder.encode(candidate_tags)
        matched = _match_library_tags(candidate_embed, tags_embed, tags_dis, all_tags)
        logger.debug("matched library tags: %s", matched)

        best = [tag for tag, _ in Counter(matched).most_common(5)]
        return "、".join(best)

    except Exception:
        # The UI expects a string back, so keep the legacy sentinel, but record
        # the real cause instead of discarding it.
        logger.exception("tag generation failed")
        return 'api error'
78
+
79
+
80
# Gradio UI: five text inputs feed `process`; its string result is shown in
# the output box when the button is clicked.
with gr.Blocks() as demo:
    # The API key is a secret, so mask it instead of echoing it in clear text.
    text_api = gr.Textbox(label='OpenAI API key', type='password')
    text_caption = gr.Textbox(label='Caption')
    text_category = gr.Textbox(label='Category')
    text_asr = gr.Textbox(label='ASR')
    text_ocr = gr.Textbox(label='OCR')

    text_output = gr.Textbox(value='', label='Output')

    btn = gr.Button(value='Submit')
    # Input order must match the positional parameters of `process`:
    # (api, caption, category, asr, ocr).
    btn.click(
        process,
        inputs=[text_api, text_caption, text_category, text_asr, text_ocr],
        outputs=[text_output],
    )


if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==3.24.1
2
+ sentence-transformers==2.2.2
3
+ openai==0.27.4
4
+ langchain==0.0.133
tag_data/tags.txt ADDED
The diff for this file is too large to render. See raw diff
 
tag_data/tags_dis.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac67bff7196a2e1f8349a5f4e8efad564c521670c3ce4ac4d162e570241534b8
3
+ size 141908
tag_data/tags_embed.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ed5b1c3ed770571fe690bff611041e4d87bfc9bc0fa50e9c4f9b48273f5eb39
3
+ size 145182848