hungdungn47 commited on
Commit
0e04b12
·
1 Parent(s): e44af84

add infer vit5

Browse files
Files changed (3) hide show
  1. app.py +23 -25
  2. infer_concat.py +109 -0
  3. requirements.txt +7 -7
app.py CHANGED
@@ -1,8 +1,14 @@
1
  import streamlit as st
2
  from io import StringIO
3
  from chdg_inference import infer
 
4
 
5
- st.title("Tóm tắt đa văn bản tiếng Việt")
 
 
 
 
 
6
 
7
  # Initialize session state
8
  if 'num_docs' not in st.session_state:
@@ -14,40 +20,32 @@ if 'docs' not in st.session_state:
14
  def add_text_area():
15
  st.session_state.num_docs += 1
16
 
17
-
18
  # Button to add a new text area
19
- st.button("Thêm văn bản", on_click=add_text_area)
20
 
21
  # Display text areas for document input
22
  for i in range(st.session_state.num_docs):
23
- doc = st.text_area(f"Văn bản {i+1}", key=f"doc_{i}", height=200)
24
  doc.replace('\r', '\n')
 
25
  if len(st.session_state.docs) <= i:
26
  st.session_state.docs.append(doc)
27
  else:
28
  st.session_state.docs[i] = doc
29
- # Display the documents for verification
30
- # st.write("**Entered Documents:**")
31
- # st.write(st.session_state.docs)
32
- # uploaded_file = st.file_uploader(label="Chọn file văn bản")
33
 
34
- category = st.selectbox("Chọn chủ để của văn bản: ", ['Giáo dục', 'Giải trí - Thể thao', 'Khoa học - Công nghệ', 'Kinh tế', 'Pháp luật', 'Thế giới', 'Văn hóa - Xã hội', 'Đời sống'])
35
 
36
  def summarize():
37
- # if uploaded_file is not None:
38
- # stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
39
- # full_text = stringio.read()
40
- # summ, docs = infer(full_text, category)
41
- # st.subheader("Kết quả: ")
42
- # st.write(summ)
43
- # st.subheader("Docs: ")
44
- # st.write(docs)
45
- # else:
46
- # st.error("Hãy tải file văn bản lên")
47
- summ, docs = infer(st.session_state.docs, category)
48
- st.subheader("Kết quả")
49
- st.write(summ)
50
- st.write(docs)
51
-
52
- if st.button("Tóm tắt"):
53
  summarize()
 
1
  import streamlit as st
2
  from io import StringIO
3
  from chdg_inference import infer
4
+ from infer_concat import vit5_infer
5
 
6
+ st.set_page_config(layout="wide")
7
+ st.title("Tóm tắt Đa văn bản Tiếng Việt")
8
+
9
+ col1, col2 = st.columns([1, 1])
10
+ col2_title, = col2.columns(1)
11
+ col2_chdg, col2_vit5 = col2.columns(2)
12
 
13
  # Initialize session state
14
  if 'num_docs' not in st.session_state:
 
20
  def add_text_area():
21
  st.session_state.num_docs += 1
22
 
 
23
  # Button to add a new text area
24
+ col1.button("Thêm văn bản", on_click=add_text_area)
25
 
26
  # Display text areas for document input
27
  for i in range(st.session_state.num_docs):
28
+ doc = col1.text_area(f"Văn bản {i+1}", key=f"doc_{i}", height=150)
29
  doc.replace('\r', '\n')
30
+ doc.replace('\"', "'")
31
  if len(st.session_state.docs) <= i:
32
  st.session_state.docs.append(doc)
33
  else:
34
  st.session_state.docs[i] = doc
 
 
 
 
35
 
36
+ category = col1.selectbox("Chọn chủ để của văn bản: ", ['Giáo dục', 'Giải trí - Thể thao', 'Khoa học - Công nghệ', 'Kinh tế', 'Pháp luật', 'Thế giới', 'Văn hóa - Xã hội', 'Đời sống'])
37
 
38
  def summarize():
39
+ summ, _ = infer(st.session_state.docs, category)
40
+ with col2.container():
41
+ col2_title.subheader("Kết quả: ")
42
+ col2_title.write("\n")
43
+
44
+ with col2.container():
45
+ col2_chdg.write("CHDG:")
46
+ col2_chdg.write(summ)
47
+ summ_vit5 = vit5_infer(st.session_state.docs)
48
+ col2_vit5.write(summ_vit5)
49
+
50
+ if col1.button("Tóm tắt"):
 
 
 
 
51
  summarize()
infer_concat.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # create dataset class
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import torch
4
+ import json
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+ import time
7
+
8
+
9
+ class Dataset4Summarization(Dataset):
10
+ def __init__(self, data, tokenizer, max_length=1024*3, chunk_length =1024):
11
+ self.data = data
12
+ self.tokenizer = tokenizer
13
+ self.max_length = max_length
14
+ self.chunk_length = chunk_length
15
+
16
+ def __len__(self):
17
+ return len(self.data)
18
+
19
+ def chunking(self, text):
20
+ chunks = []
21
+ for i in range(0, self.max_length, self.chunk_length):
22
+ chunks.append(text[i:i+self.chunk_length])
23
+ return chunks
24
+
25
+ def __getitem__(self, idx):
26
+ sample = self.data[idx]
27
+ inputs = self.tokenizer(sample, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
28
+
29
+ list_chunk = self.chunking(inputs['input_ids'].squeeze())
30
+ list_attention_mask = self.chunking(inputs['attention_mask'].squeeze())
31
+
32
+
33
+ return {
34
+ 'list_input_ids': list_chunk,
35
+ 'list_att_mask' : list_attention_mask,
36
+ }
37
+
38
+
39
+ def process_data_infer(data):
40
+ single_documents = data.get('single_documents', [])
41
+
42
+
43
+ result = []
44
+ for doc in single_documents:
45
+ raw_text = doc.get('raw_text', '')
46
+ result.append(raw_text)
47
+
48
+ return " ".join(result)
49
+
50
+
51
+ def processing_data_infer(input_file):
52
+ all_results = []
53
+
54
+ with open(input_file, 'r', encoding='utf-8') as file:
55
+ for line in file:
56
+ data = json.loads(line.strip())
57
+ result = process_data_infer(data)
58
+ all_results.append(result)
59
+
60
+ return all_results
61
+
62
+ # Load model and tokenizer
63
+ tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base-vietnews-summarization")
64
+ model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base-vietnews-summarization")
65
+
66
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
67
+ model.to(device)
68
+
69
+ model.load_state_dict(torch.load("./weight_cp19_model.pth", map_location=torch.device('cpu')))
70
+
71
+ # For other demo purpose, you just need to make sure data is list of documents [document1, document2]
72
+
73
+ # batch_size need to be 1,
74
+ @torch.no_grad()
75
+ def infer_2_hier(model, data_loader, device, tokenizer):
76
+ model.eval()
77
+ start = time.time()
78
+ all_summaries = []
79
+ for iter in data_loader:
80
+ summaries = []
81
+ inputs = iter['list_input_ids']
82
+ att_mask = iter['list_att_mask']
83
+
84
+ for i in range(len(inputs)):
85
+ # Check if the input tensor is all zeros
86
+ if torch.all(inputs[i] == 0):
87
+ # If the input is all zeros, skip this iteration
88
+ continue
89
+ else:
90
+ summary = model.generate(inputs[i].to(device),
91
+ attention_mask=att_mask[i].to(device),
92
+ max_length=128,
93
+ num_beams=12,
94
+ num_return_sequences=1)
95
+ summaries.append(summary)
96
+ summaries = torch.cat(summaries, dim = 1)
97
+ for k in summaries:
98
+ all_summaries.append(tokenizer.decode(k, skip_special_tokens=True))
99
+
100
+
101
+ end = time.time()
102
+ print(f"Time: {end-start}")
103
+ return all_summaries
104
+
105
+ def vit5_infer(data):
106
+ dataset = Dataset4Summarization(data, tokenizer)
107
+ data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=2)
108
+ result = infer_2_hier(model, data_loader, device, tokenizer)
109
+ return result
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
- torch
2
- rouge
3
- transformers
4
- underthesea
5
- numpy
6
- pandas
7
- scikit-learn
 
1
+ torch==2.1.2
2
+ rouge==1.0.1
3
+ transformers==4.39.2
4
+ underthesea==6.8.4
5
+ numpy==1.25.1
6
+ pandas==2.1.1
7
+ scikit-learn==1.3.0