metricv commited on
Commit
96a7d84
·
1 Parent(s): d06f65e

Update segmenter

Browse files
Files changed (4) hide show
  1. data +1 -1
  2. model.py +30 -10
  3. segmenter.ckpt +1 -1
  4. train.py +1 -1
data CHANGED
@@ -1 +1 @@
1
- Subproject commit 33c57a3cafbdb46b4cc7db7f08695d63b52d6668
 
1
+ Subproject commit dd266799aedd72e6381b368eacbe2767b6174aad
model.py CHANGED
@@ -5,6 +5,8 @@ from torch.utils.data import Dataset, DataLoader
5
  import numpy as np
6
  from os import listdir
7
  from os.path import isfile, join
 
 
8
 
9
  if __package__ == None or __package__ == "":
10
  from utils import tag_training_data, get_upenn_tags_dict, parse_tags
@@ -79,20 +81,38 @@ class SegmentorDatasetDirectTag(Dataset):
79
 
80
  # The same dataset without one-hot embedding of the input.
81
  class SegmentorDatasetNonEmbed(Dataset):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def __init__(self, document_root: str):
83
  self.datapoints = []
84
 
85
  files = listdir(document_root)
86
- for f in files:
87
- if f.endswith(".txt"):
88
- fname = join(document_root, f)
89
- print(f"Loaded datafile: {fname}")
90
- reconstructed_tags = tag_training_data(fname)
91
- input, tag = parse_tags(reconstructed_tags)
92
- self.datapoints.append((
93
- np.array(input),
94
- np.array(tag)
95
- ))
 
 
 
 
96
 
97
  def __len__(self):
98
  return len(self.datapoints)
 
5
  import numpy as np
6
  from os import listdir
7
  from os.path import isfile, join
8
+ import concurrent
9
+ import itertools
10
 
11
  if __package__ == None or __package__ == "":
12
  from utils import tag_training_data, get_upenn_tags_dict, parse_tags
 
81
 
82
  # The same dataset without one-hot embedding of the input.
83
  class SegmentorDatasetNonEmbed(Dataset):
84
+ @staticmethod
85
+ def read_file(f: str, document_root: str):
86
+ if f.endswith(".txt"):
87
+ fname = join(document_root, f)
88
+ print(f"Loaded datafile: {fname}")
89
+ reconstructed_tags = tag_training_data(fname)
90
+ input, tag = parse_tags(reconstructed_tags)
91
+ return [(
92
+ np.array(input),
93
+ np.array(tag)
94
+ )]
95
+ else:
96
+ return []
97
+
98
  def __init__(self, document_root: str):
99
  self.datapoints = []
100
 
101
  files = listdir(document_root)
102
+ with concurrent.futures.ProcessPoolExecutor() as pool:
103
+ out = pool.map(SegmentorDatasetNonEmbed.read_file, files, itertools.repeat(document_root))
104
+
105
+ self.datapoints = list(itertools.chain.from_iterable(out))
106
+ # for f in files:
107
+ # if f.endswith(".txt"):
108
+ # fname = join(document_root, f)
109
+ # print(f"Loaded datafile: {fname}")
110
+ # reconstructed_tags = tag_training_data(fname)
111
+ # input, tag = parse_tags(reconstructed_tags)
112
+ # self.datapoints.append((
113
+ # np.array(input),
114
+ # np.array(tag)
115
+ # ))
116
 
117
  def __len__(self):
118
  return len(self.datapoints)
segmenter.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a8e6209584d0021684bb3a09ec1b717843f3086dfcc6411c57276f743f8e62fa
3
  size 10584544
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:005053e2036ac4a30364cdb81501140ef2ca238bee0f9a1a28fc5a4603d725f6
3
  size 10584544
train.py CHANGED
@@ -26,6 +26,6 @@ if __name__ == "__main__":
26
 
27
  model.to(device)
28
 
29
- train_bidirlstm_embedding_model(model, dataset, num_epochs=150, batch_size=2)
30
 
31
  torch.save(model.state_dict(), "segmenter.ckpt")
 
26
 
27
  model.to(device)
28
 
29
+ train_bidirlstm_embedding_model(model, dataset, num_epochs=100, batch_size=2)
30
 
31
  torch.save(model.state_dict(), "segmenter.ckpt")