class for single model
Browse files
ebart.py
CHANGED
@@ -1,46 +1,66 @@
|
|
1 |
import spaces
|
2 |
import torch
|
3 |
-
from transformers import PegasusForConditionalGeneration
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
@spaces.GPU
def generate_summary(text, max_length=180, min_length=64):
    """Generate a Chinese abstractive summary of *text* with Randeng-Pegasus.

    Args:
        text: Source document to summarize.
        max_length: Maximum token length of the generated summary.
        min_length: Minimum token length of the generated summary.

    Returns:
        The decoded summary string.
    """
    model_name = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1"
    tokenizer = PegasusTokenizer.from_pretrained(model_name)
    # BUG FIX: the original body called `model.to(device)` / `model.generate(...)`
    # on a global `model` that was never defined; load it here from the same
    # checkpoint as the tokenizer.
    model = PegasusForConditionalGeneration.from_pretrained(model_name)

    # Move the model to the GPU when one is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Tokenize (truncated to 1024 tokens) and move inputs to the model's device.
    inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(device)

    # Generate the summary with beam search.
    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=max_length,
        min_length=min_length,
        num_beams=4,
        early_stopping=True
    )

    # Decode and return the summary.
    clean_summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
    return clean_summary
|
31 |
|
32 |
if __name__ == "__main__":
    # Demo input: a consumer complaint text used to exercise the summarizer.
    # BUG FIX: the first segment of this implicitly-concatenated literal was a
    # lone unterminated `"` (a SyntaxError); the dangling quote is removed and
    # the remaining segments are kept verbatim.
    text = (
        "其宣传的特殊配方洗衣液比立白和蓝月亮的去污效果要好很多,还有油污净自称中国去污第一名,"
        "做了一些去除废机油的实验,油污净清洗废机油,洗洗液去除废机油,洗完后直接排入城市管网的下水池,"
        "每天都在进行相关测试,废机油属于危险废物,严重危害公共环境,请湖南环保局对其污染环境进行查处。"
        "其公司宣传材料存在大量虚假宣传,夸大其词,感觉就是个传销组织,说其公司有妆字号资质,药字号资质,"
        "全国工业产品餐具用洗涤剂资质,声称其设备是纯净水设备,是否有涉水批件,是否有消字号证件,"
        "其消字号所有产品是否都进行备案和匹配的检测报告,请湖南市场监督管理局对其进行查处,"
        "1997年到现在坑害全国百姓加盟其公司,请湖南商务局查处其是否具有特许经营资质,自称每年营业额1亿元多元,"
        "从97年坑害到23年大量的客户没有开局相应的发票,存在重大偷税漏税嫌疑,请湖南税务机关对其进行查处!"
        "还有其出口的设备,渠道是否正规,是白关,灰关,还是黑关,请湖南海关相关部门对其进行查处。"
    )
    summary = generate_summary(text, max_length=128, min_length=64)
    print(summary)
|
|
|
1 |
import spaces
|
2 |
import torch
|
3 |
+
from transformers import PegasusForConditionalGeneration
|
4 |
+
from tokenizers_pegasus import PegasusTokenizer
|
5 |
+
|
6 |
+
class PegasusSummarizer:
    """Process-wide singleton wrapper around the Randeng-Pegasus Chinese
    summarization model.

    The model and tokenizer are loaded exactly once, on the first
    instantiation; every subsequent ``PegasusSummarizer()`` call returns the
    same shared instance.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Lazily build the one shared instance on first use.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._init_model()
        return cls._instance

    def _init_model(self):
        # Load the tokenizer and the model from the same checkpoint.
        model_name = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1"
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)

        # Prefer the GPU when available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate_summary(self, text, max_length=180, min_length=64):
        """Summarize *text* and return the cleaned-up summary string.

        Args:
            text: Source document to summarize.
            max_length: Maximum token length of the generated summary.
            min_length: Minimum token length of the generated summary.

        Returns:
            The decoded summary with special-token markers removed.
        """
        # Tokenize (truncated to 1024 tokens) and move inputs to the model's device.
        encoded = self.tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(self.device)

        # Generate with beam search plus sampling; NOTE(review): do_sample=True
        # makes output non-deterministic across calls — confirm that is intended.
        output_ids = self.model.generate(
            encoded["input_ids"],
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            early_stopping=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=2.0,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
            do_sample=True
        )

        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]

        # Strip any residual special markers the decoder may have left behind.
        for marker in ['<pad>', '<unk>', '</s>']:
            decoded = decoded.replace(marker, '')

        return decoded
|
55 |
|
56 |
@spaces.GPU
def generate_summary(text, max_length=180, min_length=64):
    """Module-level entry point: delegate to the shared PegasusSummarizer
    singleton (constructed lazily on the first call)."""
    return PegasusSummarizer().generate_summary(text, max_length, min_length)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
if __name__ == "__main__":
    # Smoke test: summarize a sample citizen complaint about parking spaces.
    sample = (
        "东四路西侧之前有划分免费停车位,为什么后面被撤销,而道路东侧有划分免费停车车位,附近小区车位紧张导致很难找到停车位。"
    )
    print(generate_summary(sample, max_length=128, min_length=64))
|