lalital committed on
Commit
f1f90b3
·
verified ·
1 Parent(s): 055dd66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -17,11 +17,7 @@ model = BartForConditionalGeneration.from_pretrained(
17
  tokenizer = AutoTokenizer.from_pretrained(
18
  'airesearch/wangchanbart-large',
19
  )
20
-
21
- text_summarize_pipeline = pipeline('text2text-generation',
22
- tokenizer=tokenizer,
23
- model=model)
24
-
25
  css_text = """<link rel="stylesheet" href="https://www.w3schools.com/w3css/4/w3.css">"""
26
 
27
  # def render_html(items: List[Dict]):
@@ -38,17 +34,26 @@ css_text = """<link rel="stylesheet" href="https://www.w3schools.com/w3css/4/w3.
38
 
39
  # return '<div class="w3-container">' + html_text + '</div>'
40
 
41
- def summarize(text: str):
 
 
 
 
 
 
 
 
42
 
43
- results = text_summarize_pipeline(text, max_length=1024)[0]
44
- print(f'results:\n {results}')
45
- # for i, result in enumerate(results):
46
- # results[i]['label'] = LABEL_MAPPING[result['label']]
47
- # results[i]['score'] = float(round(float(result['score']), 4))
48
- # html_text = 'css_text + results'
49
- html_text = '<p>' + results['generated_text'] + '</p>'
50
- print(html_text)
51
- return html_text
 
52
 
53
 
54
  demo = gr.Interface(fn=summarize,
 
17
  tokenizer = AutoTokenizer.from_pretrained(
18
  'airesearch/wangchanbart-large',
19
  )
20
+
 
 
 
 
21
  css_text = """<link rel="stylesheet" href="https://www.w3schools.com/w3css/4/w3.css">"""
22
 
23
  # def render_html(items: List[Dict]):
 
34
 
35
  # return '<div class="w3-container">' + html_text + '</div>'
36
 
37
def summarize(input_text: str, model=None, tokenizer=None, num_beams: int = 5) -> str:
    """Summarize *input_text* with a seq2seq model and wrap the result in HTML.

    Parameters
    ----------
    input_text : str
        The text to summarize.
    model : optional
        A seq2seq model exposing ``generate``; defaults to the module-level
        ``model`` so Gradio — which calls ``fn`` with only the text input —
        can invoke this function without supplying it.
    tokenizer : optional
        Tokenizer used to encode the input and decode the prediction;
        defaults to the module-level ``tokenizer``.
    num_beams : int
        Beam-search width passed to ``model.generate``.

    Returns
    -------
    str
        The predicted summary wrapped in a ``<p>`` element.
    """
    # gr.Interface(fn=summarize, ...) passes only the text input, so resolve
    # the model/tokenizer from module scope when not given explicitly.
    if model is None:
        model = globals()['model']
    if tokenizer is None:
        tokenizer = globals()['tokenizer']

    # Encode; inputs longer than 1024 tokens are truncated, shorter ones padded.
    inputs = tokenizer(input_text,
                       return_tensors="pt",
                       max_length=1024,
                       truncation=True,
                       padding='max_length')

    # Beam search with a strong length penalty to favor longer summaries
    # within the 32–128 token window.
    predicted_token_ids = model.generate(inputs['input_ids'],
                                         num_beams=num_beams,
                                         min_length=32,
                                         max_length=128,
                                         length_penalty=10.0)
    predicted_summary = tokenizer.decode(predicted_token_ids[0], skip_special_tokens=True)

    return '<p>' + predicted_summary + '</p>'
55
+
56
+
57
 
58
 
59
  demo = gr.Interface(fn=summarize,