|
|
|
from torch.utils.data import Dataset, DataLoader |
|
import torch |
|
import json |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import time |
|
|
|
|
|
class Dataset4Summarization(Dataset):
    """Dataset that tokenizes one document per item and splits the token ids
    into fixed-size chunks for hierarchical summarization.

    Each item is a raw document string; it is tokenized, padded/truncated to
    ``max_length`` tokens, then cut into consecutive ``chunk_length``-sized
    pieces that are summarized independently downstream.
    """

    def __init__(self, data, tokenizer, max_length=1024 * 3, chunk_length=1024):
        """
        Args:
            data: sequence of raw document strings.
            tokenizer: callable tokenizer returning ``input_ids`` and
                ``attention_mask`` (HuggingFace-style).
            max_length: total padded/truncated token length per document.
            chunk_length: token length of each chunk fed to the model.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.chunk_length = chunk_length

    def __len__(self):
        return len(self.data)

    def chunking(self, text):
        """Split ``text`` (any sliceable sequence/tensor) into consecutive
        ``chunk_length``-sized slices.

        Fix: iterate over the actual length of ``text`` instead of
        ``self.max_length`` so inputs shorter than ``max_length`` no longer
        produce trailing empty chunks. For the padded, exactly-``max_length``
        tensors produced by ``__getitem__`` the output is unchanged.
        """
        return [text[i:i + self.chunk_length]
                for i in range(0, len(text), self.chunk_length)]

    def __getitem__(self, idx):
        """Tokenize document ``idx`` and return its chunked ids and mask."""
        sample = self.data[idx]
        inputs = self.tokenizer(
            sample,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )

        # squeeze() drops the leading batch dim added by return_tensors='pt'.
        list_chunk = self.chunking(inputs['input_ids'].squeeze())
        list_attention_mask = self.chunking(inputs['attention_mask'].squeeze())

        return {
            'list_input_ids': list_chunk,
            'list_att_mask': list_attention_mask,
        }
|
|
|
|
|
def process_data_infer(data):
    """Concatenate the raw text of every entry in ``data['single_documents']``.

    Args:
        data: dict that may contain a 'single_documents' list of dicts,
            each with a 'raw_text' field.

    Returns:
        The raw texts joined by single spaces (empty string when absent).
    """
    documents = data.get('single_documents', [])
    return " ".join(doc.get('raw_text', '') for doc in documents)
|
|
|
|
|
def processing_data_infer(input_file):
    """Read a JSON-lines file and build one concatenated document per line.

    Args:
        input_file: path to a file containing one JSON object per line.

    Returns:
        List with one joined raw-text string per input line.
    """
    with open(input_file, 'r', encoding='utf-8') as handle:
        return [process_data_infer(json.loads(raw.strip())) for raw in handle]
|
|
|
|
|
# Base ViT5 Vietnamese-news summarization checkpoint; tokenizer and model
# must come from the same checkpoint so vocabularies match.
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base-vietnews-summarization")

model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base-vietnews-summarization")

# Inference runs on CPU only.
device = torch.device('cpu')

model.to(device)

# Overwrite the pretrained weights with a locally fine-tuned checkpoint.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — confirm this checkpoint file comes from a trusted source.
model.load_state_dict(torch.load("./weight_cp19_model.pth", map_location=torch.device('cpu')))
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def infer_2_hier(model, data_loader, device, tokenizer):
    """Hierarchically summarize chunked documents.

    For each batch, every non-padding chunk is summarized independently; the
    chunk summaries are concatenated along the sequence dimension and decoded
    into one summary string per batch.

    Args:
        model: seq2seq model exposing ``eval`` and ``generate``.
        data_loader: yields dicts with 'list_input_ids' and 'list_att_mask',
            each a list of (1, chunk_length) tensors (see
            Dataset4Summarization with batch_size=1).
        device: torch device generation runs on.
        tokenizer: tokenizer used to decode the generated ids.

    Returns:
        List of decoded summary strings, one per batch.
    """
    model.eval()
    start = time.time()
    all_summaries = []
    # Fix: loop variable renamed from `iter`, which shadowed the builtin.
    for batch in data_loader:
        chunk_summaries = []
        input_ids = batch['list_input_ids']
        att_mask = batch['list_att_mask']

        for ids, mask in zip(input_ids, att_mask):
            # Skip chunks that are entirely padding (all-zero ids).
            if torch.all(ids == 0):
                continue
            summary = model.generate(
                ids.to(device),
                attention_mask=mask.to(device),
                max_length=128,
                num_beams=4,
                num_return_sequences=1,
                no_repeat_ngram_size=3,
            )
            chunk_summaries.append(summary)

        # Fix: torch.cat raises on an empty list; if every chunk was padding,
        # emit an empty summary instead of crashing.
        if not chunk_summaries:
            all_summaries.append("")
            continue

        merged = torch.cat(chunk_summaries, dim=1)
        all_summaries.append(tokenizer.decode(merged.squeeze(), skip_special_tokens=True))

    end = time.time()
    print(f"Time: {end-start}")
    return all_summaries
|
|
|
def vit5_infer(data):
    """Run hierarchical ViT5 summarization over ``data``.

    Uses the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        data: sequence of raw document strings.

    Returns:
        List of summary strings, one per document.
    """
    dataset = Dataset4Summarization(data, tokenizer)
    # Consistency: use the DataLoader name already imported at the top of
    # the file instead of the fully qualified torch.utils.data.DataLoader.
    data_loader = DataLoader(dataset, batch_size=1)
    return infer_2_hier(model, data_loader, device, tokenizer)