File size: 1,754 Bytes
71f183c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from tqdm import tqdm

class GloVe():
    """Word-vector lookup table loaded from a space-separated text file.

    Each non-blank line of the file is read as a token followed by its
    float components, all separated by single spaces. Vectors are kept in
    ``self.embedding`` (token -> 1-D float tensor). ``self.dimension`` is
    the length of the first vector read, and stays ``None`` if the file
    contains no vectors.

    Indexing with a string or an iterable of strings returns the mean
    vector of the words that could be resolved; see ``__getitem__``.
    """

    def __init__(self, file_path):
        """Load every vector from the text file at *file_path*."""
        self.dimension = None
        self.embedding = dict()
        # Decode as UTF-8 explicitly so that non-ASCII tokens do not depend
        # on the platform's default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            # Iterate the file object directly: lines are streamed instead
            # of materialising the whole file with readlines(). The
            # progress bar therefore counts lines without a known total.
            for line in tqdm(f):
                line = line.strip()
                if not line:
                    # A blank line would otherwise be stored as the token ''
                    # with an empty tensor, which breaks later additions.
                    continue
                strs = line.split(' ')
                word = strs[0]
                vector = torch.FloatTensor(list(map(float, strs[1:])))
                self.embedding[word] = vector
                if self.dimension is None:
                    self.dimension = len(vector)

    def _average(self, vectors):
        """Return the mean of the non-None tensors in *vectors*.

        Returns ``None`` when *vectors* yields no tensor at all.
        """
        # Accumulate into a fresh tensor so that the stored embeddings are
        # never modified by the in-place addition.
        total = self.zeros()
        count = 0
        for vec in vectors:
            if vec is not None:
                total += vec
                count += 1
        return total / count if count > 0 else None

    def _lookup_term(self, term):
        """Return the vector for *term*, or ``None`` if it cannot be resolved.

        A term missing from the table is split on '-' and resolved to the
        mean of the parts that are known.
        """
        v = self.embedding.get(term)
        if v is None:
            v = self._average(
                self.embedding.get(subterm) for subterm in term.split('-')
            )
        return v

    def _fix_word(self, word):
        """Approximate a vector for a word that is missing from the table.

        Underscores are treated as spaces and the word is split on single
        spaces. Each resulting term is resolved directly or, failing that,
        through its hyphen-separated parts. The result is the mean over the
        terms that could be resolved, or ``None`` if none could.
        """
        terms = word.replace('_', ' ').split(' ')
        return self._average(self._lookup_term(term) for term in terms)

    def _lookup_word(self, word):
        """Return the stored vector for *word*, else the ``_fix_word`` result."""
        v = self.embedding.get(word)
        if v is None:
            v = self._fix_word(word)
        return v

    def __getitem__(self, words):
        """Return the mean vector of *words*.

        *words* is a single string or an iterable of strings. Words that
        cannot be resolved, even through ``_fix_word``, are left out of the
        mean. If nothing can be resolved, a zero vector of length
        ``self.dimension`` is returned.
        """
        # isinstance (not an exact type check) so that str subclasses count
        # as one word instead of being iterated character by character.
        if isinstance(words, str):
            words = [words]
        mean = self._average(self._lookup_word(word) for word in words)
        return mean if mean is not None else self.zeros()

    def zeros(self):
        """Return a new zero tensor of length ``self.dimension``."""
        return torch.zeros(self.dimension)