saily committed on
Commit
495e47a
·
1 Parent(s): fb15db5

class for model single

Browse files
Files changed (1) hide show
  1. ebart.py +55 -35
ebart.py CHANGED
@@ -1,46 +1,66 @@
1
  import spaces
2
  import torch
3
- from transformers import PegasusForConditionalGeneration, PegasusTokenizer
4
- #from tokenizers_pegasus import PegasusTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  @spaces.GPU
7
  def generate_summary(text, max_length=180, min_length=64):
8
- # 加载标记器和模型
9
- model = PegasusForConditionalGeneration.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1")
10
- tokenizer = PegasusTokenizer.from_pretrained("IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1")
11
-
12
- # 将模型移动到GPU
13
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- model.to(device)
15
-
16
- # 进行标记化并将输入数据移动到GPU
17
- inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(device)
18
-
19
- # 生成摘要
20
- summary_ids = model.generate(
21
- inputs["input_ids"],
22
- max_length=max_length,
23
- min_length=min_length,
24
- num_beams=4,
25
- early_stopping=True
26
- )
27
-
28
- # 解码并返回摘要
29
- clean_summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
30
- return clean_summary
31
 
32
  if __name__ == "__main__":
33
  text = (
34
- "2023年3月16日我们从黑龙江大庆来到湖南长沙市长沙县福中路77号湖南省富达日化有限公司,"
35
- "其宣传的特殊配方洗衣液比立白和蓝月亮的去污效果要好很多,还有油污净自称中国去污第一名,"
36
- "做了一些去除废机油的实验,油污净清洗废机油,洗洗液去除废机油,洗完后直接排入城市管网的下水池,"
37
- "每天都在进行相关测试,废机油属于危险废物,严重危害公共环境,请湖南环保局对其污染环境进行查处。"
38
- "其公司宣传材料存在大量虚假宣传,夸大其词,感觉就是个传销组织,说其公司有妆字号资质,药字号资质,"
39
- "全国工业产品餐具用洗涤剂资质,声称其设备是纯净水设备,是否有涉水批件,是否有消字号证件,"
40
- "其消字号所有产品是否都进行备案和匹配的检测报告,请湖南市场监督管理局对其进行查处,"
41
- "1997年到现在坑害全国百姓加盟其公司,请湖南商务局查处其是否具有特许经营资质,自称每年营业额1亿元多元,"
42
- "从97年坑害到23年大量的客户没有开局相应的发票,存在重大偷税漏税嫌疑,请湖南税务机关对其进行查处!"
43
- "还有其出口的设备,渠道是否正规,是白关,灰关,还是黑关,请湖南海关相关部门对其进行查处。"
44
  )
45
  summary = generate_summary(text, max_length=128, min_length=64)
46
  print(summary)
 
1
  import spaces
2
  import torch
3
+ from transformers import PegasusForConditionalGeneration
4
+ from tokenizers_pegasus import PegasusTokenizer
5
+
6
class PegasusSummarizer:
    """Process-wide singleton wrapper around the Randeng Pegasus Chinese
    summarization model (IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1).

    The model and tokenizer are loaded once, on first construction, and
    every later instantiation returns the same initialized object.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Build and initialize the model exactly once per process.
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._init_model()
            cls._instance = instance
        return cls._instance

    def _init_model(self):
        # Load the tokenizer and the sequence-to-sequence model.
        model_name = "IDEA-CCNL/Randeng-Pegasus-523M-Summary-Chinese-V1"
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)

        # Prefer the GPU when one is available; fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate_summary(self, text, max_length=180, min_length=64):
        """Summarize *text* and return the decoded summary string.

        Inputs longer than 1024 tokens are truncated before generation.
        NOTE(review): beam search (num_beams=4) is combined here with
        sampling (do_sample=True, temperature/top_k/top_p) — output is
        therefore stochastic; confirm this is intended.
        """
        # Tokenize and move the encoded input to the model's device.
        encoded = self.tokenizer(
            text, max_length=1024, truncation=True, return_tensors="pt"
        ).to(self.device)

        # Generate candidate summary token ids.
        output_ids = self.model.generate(
            encoded["input_ids"],
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            early_stopping=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=2.0,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
            do_sample=True
        )

        # Decode the first (only) returned sequence.
        decoded = self.tokenizer.batch_decode(
            output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        summary = decoded[0]

        # Defensively strip any special-token text that survived decoding.
        for marker in ('<pad>', '<unk>', '</s>'):
            summary = summary.replace(marker, '')

        return summary
55
 
56
@spaces.GPU
def generate_summary(text, max_length=180, min_length=64):
    """GPU entry point: summarize *text* via the shared PegasusSummarizer.

    Constructing PegasusSummarizer is cheap after the first call because
    the class is a singleton holding the loaded model.
    """
    return PegasusSummarizer().generate_summary(text, max_length, min_length)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
if __name__ == "__main__":
    # Smoke test: summarize a sample Chinese complaint and print it.
    sample_text = (
        "东四路西侧之前有划分免费停车位,为什么后面被撤销,而道路东侧有划分免费停车车位,附近小区车位紧张导致很难找到停车位。"
    )
    print(generate_summary(sample_text, max_length=128, min_length=64))