|
|
|
|
|
|
|
|
|
@@ -8,6 +8,7 @@ |
|
|
|
# here put the import lib |
|
import os |
|
+import pathlib |
|
import sys |
|
import math |
|
import random |
|
@@ -58,7 +59,8 @@ class CoglmStrategy: |
|
self._is_done = False |
|
self.outlier_count_down = torch.zeros(16) |
|
self.vis_list = [[]for i in range(16)] |
|
- self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long) |
|
+ cluster_label_path = pathlib.Path(__file__).parent / 'cluster_label2.npy' |
|
+ self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long) |
|
self.start_pos = -1 |
|
self.white_cluster = [] |
|
# self.fout = open('tmp.txt', 'w') |
|
@@ -98,4 +100,4 @@ class CoglmStrategy: |
|
|
|
def finalize(self, tokens, mems): |
|
self._is_done = False |
|
- return tokens, mems |
|
\ No newline at end of file |
|
+ return tokens, mems |
|
|
|
|
|
|
|
|
|
@@ -8,6 +8,7 @@ |
|
|
|
# here put the import lib |
|
import os |
|
+import pathlib |
|
import sys |
|
import math |
|
import random |
|
@@ -28,7 +29,8 @@ class IterativeEntfilterStrategy: |
|
self.invalid_slices = invalid_slices |
|
self.temperature = temperature |
|
self.topk = topk |
|
- self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long) |
|
+ cluster_label_path = pathlib.Path(__file__).parents[1] / 'cluster_label2.npy' |
|
+ self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long) |
|
|
|
|
|
def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None): |
|
|