|
from typing import Any |
|
import torch |
|
from torch import nn |
|
from torch.utils.data import Dataset, DataLoader |
|
import numpy as np |
|
from os import listdir |
|
from os.path import isfile, join |
|
|
|
if __package__ == None or __package__ == "": |
|
from utils import tag_training_data, get_upenn_tags_dict, parse_tags |
|
else: |
|
from .utils import tag_training_data, get_upenn_tags_dict, parse_tags |
|
|
|
|
|
class SegmentorDataset(Dataset):
    """In-memory dataset of ``(features, target)`` pairs held as float tensors.

    Args:
        datapoints: iterable of ``(array, scalar)`` pairs. Each array is
            converted with ``torch.from_numpy(...).float()`` and each scalar is
            wrapped into a one-element float tensor.
    """

    def __init__(self, datapoints):
        converted = []
        for features, target in datapoints:
            feature_tensor = torch.from_numpy(features).float()
            target_tensor = torch.tensor([target]).float()
            converted.append((feature_tensor, target_tensor))
        self.datapoints = converted

    def __len__(self):
        return len(self.datapoints)

    def __getitem__(self, idx):
        """Return the ``(features, target)`` tensors stored at ``idx``."""
        entry = self.datapoints[idx]
        return entry[0], entry[1]
|
|
|
class RNN(nn.Module):
    """Unidirectional LSTM that maps each input sequence to a single value.

    The LSTM runs ``batch_first``; only the output at the final time step is
    passed through a linear layer, giving one prediction per sequence.

    Args:
        input_size: number of features per time step.
        hidden_size: size of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
        device: device for the parameters. When ``None`` it resolves to
            ``"cuda"`` if available, otherwise ``"cpu"``.
    """

    def __init__(self, input_size, hidden_size, num_layers, device=None):
        super().__init__()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # Build the layers directly on ``self.device`` (as the bidirectional
        # models do) so the parameters sit where callers send the inputs.
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers,
                           batch_first=True, device=self.device)
        self.fc = nn.Linear(hidden_size, 1, device=self.device)

    def forward(self, x):
        """Map a ``(batch, seq, input_size)`` tensor to a ``(batch, 1)`` output."""
        # Zero initial states are created on the input's device/dtype, so the
        # model keeps working after being moved with ``.to(...)``.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)
        out, _ = self.rnn(x, (h0, c0))

        # Keep only the last time step of every sequence.
        # NOTE(review): with zero-padded batches this is a padded position for
        # the shorter sequences in the batch.
        return self.fc(out[:, -1, :])
|
|
|
|
|
class SegmentorDatasetDirectTag(Dataset):
    """Dataset of tagged ``.txt`` documents whose inputs are one-hot encoded.

    Every regular file ending in ``.txt`` directly inside ``document_root`` is
    run through ``tag_training_data`` and ``parse_tags``. The resulting index
    sequence and target sequence are stored as numpy arrays. Indices are
    expanded into one-hot rows lazily in ``__getitem__``.

    Args:
        document_root: directory containing the training ``.txt`` files.
    """

    def __init__(self, document_root: str):
        self.tags_dict = get_upenn_tags_dict()
        self.datapoints = []
        # Identity matrix: row k is the one-hot vector for index k, with one
        # row per entry of ``tags_dict``.
        self.eye = np.eye(len(self.tags_dict))

        # Sorted for a deterministic load order (``listdir`` order is
        # arbitrary). Directories that happen to end in ".txt" are skipped.
        for name in sorted(listdir(document_root)):
            fname = join(document_root, name)
            if not name.endswith(".txt") or not isfile(fname):
                continue
            print(f"Loaded datafile: {fname}")
            reconstructed_tags = tag_training_data(fname)
            indices, targets = parse_tags(reconstructed_tags)
            self.datapoints.append((np.array(indices), np.array(targets)))

    def __len__(self):
        return len(self.datapoints)

    def __getitem__(self, idx):
        """Return ``(one_hot_inputs, targets)`` as float32 tensors for ``idx``."""
        indices, targets = self.datapoints[idx]
        # Fancy-indexing the identity matrix turns each index into its one-hot row.
        return torch.from_numpy(self.eye[indices]).float(), torch.from_numpy(targets).float()
|
|
|
|
|
class SegmentorDatasetNonEmbed(Dataset):
    """Dataset of tagged ``.txt`` documents whose inputs stay as integer indices.

    Every regular file ending in ``.txt`` directly inside ``document_root`` is
    run through ``tag_training_data`` and ``parse_tags``. The resulting input
    and target sequences are stored as numpy arrays. Inputs are returned as
    int32 tensors, presumably for an embedding layer such as
    ``BidirLSTMSegmenterWithEmbedding``.

    Args:
        document_root: directory containing the training ``.txt`` files.
    """

    def __init__(self, document_root: str):
        self.datapoints = []

        # Sorted for a deterministic load order (``listdir`` order is
        # arbitrary). Directories that happen to end in ".txt" are skipped.
        for name in sorted(listdir(document_root)):
            fname = join(document_root, name)
            if not name.endswith(".txt") or not isfile(fname):
                continue
            print(f"Loaded datafile: {fname}")
            reconstructed_tags = tag_training_data(fname)
            indices, targets = parse_tags(reconstructed_tags)
            self.datapoints.append((np.array(indices), np.array(targets)))

    def __len__(self):
        return len(self.datapoints)

    def __getitem__(self, idx):
        """Return ``(int32 inputs, float32 targets)`` tensors for ``idx``."""
        indices, targets = self.datapoints[idx]
        return torch.from_numpy(indices).int(), torch.from_numpy(targets).float()
|
|
|
class BidirLSTMSegmenter(nn.Module):
    """Bidirectional LSTM that emits a sigmoid score for every time step.

    Args:
        input_size: number of features per time step.
        hidden_size: hidden-state size of each LSTM direction.
        num_layers: number of stacked bidirectional LSTM layers.
        device: device for the parameters. When ``None`` it resolves to
            ``"cuda"`` if available, otherwise ``"cpu"``.
    """

    def __init__(self, input_size, hidden_size, num_layers, device=None):
        super().__init__()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True,
                           bidirectional=True, device=self.device)
        # Forward and backward outputs are concatenated, hence 2 * hidden_size.
        self.fc = nn.Linear(2 * hidden_size, 1, device=self.device)
        self.final = nn.Sigmoid()

    def forward(self, x):
        """Map a ``(batch, seq, input_size)`` tensor to ``(batch, seq)`` scores in [0, 1]."""
        # Two directions per layer -> num_layers * 2 initial states. They are
        # created on the input's device/dtype so ``.to(...)`` moves keep working.
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)
        out, _ = self.rnn(x, (h0, c0))

        # (batch, seq, 1) -> (batch, seq)
        out_fced = self.fc(out)[:, :, 0]
        return self.final(out_fced)
|
|
|
class BidirLSTMSegmenterWithEmbedding(nn.Module):
    """Embedding + bidirectional LSTM that emits a sigmoid score per time step.

    Args:
        input_size: number of distinct input indices (embedding vocabulary size).
        embedding_size: dimension of each embedding vector.
        hidden_size: hidden-state size of each LSTM direction.
        num_layers: number of stacked bidirectional LSTM layers.
        device: device for the parameters. When ``None`` it resolves to
            ``"cuda"`` if available, otherwise ``"cpu"``.
    """

    def __init__(self, input_size, embedding_size, hidden_size, num_layers, device=None):
        super().__init__()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size

        self.embedding = nn.Embedding(input_size, embedding_dim=embedding_size, device=self.device)
        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True,
                           bidirectional=True, device=self.device)
        # Forward and backward outputs are concatenated, hence 2 * hidden_size.
        self.fc = nn.Linear(2 * hidden_size, 1, device=self.device)
        self.final = nn.Sigmoid()

    def forward(self, x):
        """Map a ``(batch, seq)`` index tensor to ``(batch, seq)`` scores in [0, 1]."""
        embedded = self.embedding(x)
        # Initial states are derived from the (float) embeddings, not from the
        # integer indices, so they get the right dtype and device.
        h0 = embedded.new_zeros(self.num_layers * 2, x.size(0), self.hidden_size)
        c0 = embedded.new_zeros(self.num_layers * 2, x.size(0), self.hidden_size)
        out, _ = self.rnn(embedded, (h0, c0))

        # (batch, seq, 1) -> (batch, seq)
        out_fced = self.fc(out)[:, :, 0]
        return self.final(out_fced)
|
|
|
def collate_fn_padd(batch):
    """Zero-pad a batch of variable-length ``(input, target)`` samples.

    Intended as a ``DataLoader`` ``collate_fn``. Inputs and targets are padded
    independently with ``torch.nn.utils.rnn.pad_sequence`` (padding value 0)
    along their first dimension and stacked with the batch dimension first.

    Args:
        batch: sequence of samples whose element 0 is the input tensor and
            element 1 is the target tensor.

    Returns:
        ``(padded_inputs, padded_targets)``.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    sample_inputs = []
    sample_targets = []
    for sample in batch:
        sample_inputs.append(sample[0])
        sample_targets.append(sample[1])

    padded_inputs = pad(sample_inputs, batch_first=True)
    padded_targets = pad(sample_targets, batch_first=True)
    return (padded_inputs, padded_targets)
|
|
|
def get_dataloader(dataset: Dataset, batch_size, shuffle: bool = True):
    """Wrap ``dataset`` in a ``DataLoader`` that zero-pads each batch.

    Any map-style dataset yielding ``(input, target)`` tensor pairs works: all
    dataset classes in this module are passed here by the training functions.

    Args:
        dataset: dataset to iterate.
        batch_size: number of samples per batch.
        shuffle: reshuffle every epoch (default ``True``, the previous
            hard-coded behavior).

    Returns:
        A ``DataLoader`` using ``collate_fn_padd`` to pad variable-length samples.
    """
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn_padd)
|
|
|
def train_model(model: RNN,
                dataset,
                lr=1e-3,
                num_epochs=3,
                batch_size=100,
                log_every=100,
                ):
    """Train a sequence-to-scalar model with MSE loss and AdamW.

    Batches come from ``get_dataloader`` (shuffled, zero-padded). The loss is
    printed at step 0 and then every ``log_every`` steps of each epoch.

    Args:
        model: module exposing a ``device`` attribute (e.g. ``RNN``). It is
            moved to that device and put into training mode.
        dataset: dataset yielding ``(input, target)`` tensor pairs.
        lr: AdamW learning rate.
        num_epochs: number of passes over ``dataset``.
        batch_size: samples per batch.
        log_every: print interval in steps. ``0``/``None`` disables printing.
    """
    device = model.device
    # Batches are sent to ``device`` below, so the parameters must live there
    # too. Move them before building the optimizer.
    model.to(device)
    model.train()

    train_loader = get_dataloader(dataset, batch_size=batch_size)
    n_total_steps = len(train_loader)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        for i, (inputs, tags) in enumerate(train_loader):
            inputs = inputs.to(device)
            tags = tags.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, tags)
            loss.backward()
            optimizer.step()

            if log_every and i % log_every == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss [{loss.item():.4f}]")
|
|
|
def train_bidirlstm_model(model: BidirLSTMSegmenter,
                          dataset: SegmentorDatasetDirectTag,
                          lr=1e-3,
                          num_epochs=3,
                          batch_size=1,
                          log_every=10,
                          ):
    """Train a per-time-step segmenter with binary cross-entropy and AdamW.

    ``nn.BCELoss`` expects probabilities, so ``model`` should end in a
    sigmoid (as ``BidirLSTMSegmenter`` does). The loss is printed at step 0
    and then every ``log_every`` steps of each epoch.

    Args:
        model: module exposing a ``device`` attribute. It is moved to that
            device and put into training mode.
        dataset: dataset yielding ``(input, target)`` tensor pairs.
        lr: AdamW learning rate.
        num_epochs: number of passes over ``dataset``.
        batch_size: samples per batch.
        log_every: print interval in steps. ``0``/``None`` disables printing.
    """
    device = model.device
    # Batches are sent to ``device`` below, so the parameters must live there
    # too. Move them before building the optimizer.
    model.to(device)
    model.train()

    train_loader = get_dataloader(dataset, batch_size=batch_size)
    n_total_steps = len(train_loader)
    criterion = nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        for i, (inputs, tags) in enumerate(train_loader):
            inputs = inputs.to(device)
            tags = tags.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            # NOTE(review): with batch_size > 1 the zero-padded positions added
            # by the collate function also contribute to this loss.
            loss = criterion(outputs, tags)
            loss.backward()
            optimizer.step()

            if log_every and i % log_every == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss [{loss.item():.4f}]")
|
|
|
def train_bidirlstm_embedding_model(model: BidirLSTMSegmenterWithEmbedding,
                                    dataset: SegmentorDatasetNonEmbed,
                                    lr=1e-3,
                                    num_epochs=3,
                                    batch_size=1,
                                    log_every=10,
                                    ):
    """Train an embedding-based segmenter with binary cross-entropy and AdamW.

    ``nn.BCELoss`` expects probabilities, so ``model`` should end in a
    sigmoid (as ``BidirLSTMSegmenterWithEmbedding`` does). The loss is printed
    at step 0 and then every ``log_every`` steps of each epoch.

    Args:
        model: module exposing a ``device`` attribute. It is moved to that
            device and put into training mode.
        dataset: dataset yielding ``(index_input, target)`` tensor pairs.
        lr: AdamW learning rate.
        num_epochs: number of passes over ``dataset``.
        batch_size: samples per batch.
        log_every: print interval in steps. ``0``/``None`` disables printing.
    """
    device = model.device
    # Batches are sent to ``device`` below, so the parameters must live there
    # too. Move them before building the optimizer.
    model.to(device)
    model.train()

    train_loader = get_dataloader(dataset, batch_size=batch_size)
    n_total_steps = len(train_loader)
    criterion = nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        for i, (inputs, tags) in enumerate(train_loader):
            inputs = inputs.to(device)
            tags = tags.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            # NOTE(review): with batch_size > 1 the zero-padded positions added
            # by the collate function also contribute to this loss.
            loss = criterion(outputs, tags)
            loss.backward()
            optimizer.step()

            if log_every and i % log_every == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss [{loss.item():.4f}]")
|
|