import torch

from torch import nn
from collections import OrderedDict
from torch.utils.checkpoint import checkpoint
from .feature import *
from .pyenv import PyEnv
from .encode import Encode
from .decode import Decode


class Agent(nn.Module):
    """Encoder-decoder policy: encodes a problem once, then decodes actions
    step by step against an environment (or replays a given solution)."""

    def __init__(self, nn_args):
        super().__init__()

        self.nn_args = nn_args
        self.vars_dim = sum(nn_args['variable_dim'].values())
        self.steps_ratio = nn_args.setdefault('decode_steps_ratio', 1.0)

        logit_clips = nn_args.setdefault('decode_logit_clips', 10.0)
        if isinstance(logit_clips, str):
            self.logit_clips = [float(v) for v in logit_clips.split(',')]
        else:
            self.logit_clips = [float(logit_clips)]
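
        # Illustrative example (not part of the original logic): with
        # nn_args['decode_logit_clips'] = '10,5', self.logit_clips becomes
        # [10.0, 5.0]; the clip applied at decode step t is
        # self.logit_clips[min(t, len(self.logit_clips) - 1)], so the last
        # value repeats once the schedule is exhausted.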

        self.nn_encode = Encode(nn_args)
        self.nn_decode = Decode(nn_args)

    def nn_args_dict(self):
        return self.nn_args

    def forward(self, problem, batch_size, greedy=False, solution=None, memopt=0):
        X, K, V = self.nn_encode(problem.feats, problem.batch_size,
                                 problem.worker_num, problem.task_num, memopt)

        return self.interact(problem, X, K, V, batch_size, greedy, solution, memopt)

    def interact(self, problem, X, K, V, batch_size, greedy, solution, memopt):
        NP = problem.batch_size  # number of problem instances
        NW = problem.worker_num  # workers per instance
        NT = problem.task_num    # tasks per instance

        # Each problem instance is rolled out sample_num times.
        sample_num = batch_size // NP
        assert sample_num > 0 and batch_size % NP == 0

        MyEnv = problem.environment
        if MyEnv is None:
            env = PyEnv(problem, batch_size, sample_num, self.nn_args)
        else:
            env = MyEnv(str(problem.device), problem.feats, batch_size,
                        sample_num, problem.worker_num, problem.task_num)

        # Decoder context: the current query and two carried states, all
        # zero-initialized.
        query = X.new_zeros(batch_size, X.size(-1))
        state1 = X.new_zeros(batch_size, X.size(-1))
        state2 = X.new_zeros(batch_size, X.size(-1))

        p_list = []
        NULL = X.new_ones(0)  # empty placeholder when no variable features are used
        # Map each rollout in the sample batch back to its problem instance.
        p_index = torch.div(torch.arange(batch_size, device=X.device),
                            sample_num, rounding_mode='trunc')
        if solution is not None:
            # Replay a provided solution (teacher forcing): each (kind, index)
            # pair is flattened to a single action id via the per-kind offsets
            # [0, NW, 2 * NW, 2 * NW + NT].
            solution = solution[:, :, 0:2].to(torch.int64).permute(1, 0, 2)
            assert torch.all(solution >= 0) and solution.size(1) == batch_size
            offset = torch.tensor([0, NW, NW + NW, NW + NW + NT], device=X.device)
            chosen_list = solution[:, :, 1] + offset[solution[:, :, 0]]

            mode = 0
            sample_p = torch.rand(batch_size, device=X.device)
            for chosen in chosen_list:
                env_time = env.time()
                clip = self.logit_clips[min(env_time, len(self.logit_clips) - 1)]
                varfeat = env.make_feat() if self.vars_dim > 0 else NULL
                state1, state2, chosen_p = self.decode(X, K, V, query, state1, state2,
                                                       varfeat, env.mask(), chosen, sample_p, clip, mode, memopt)
                query = X[p_index, chosen]
                p_list.append(chosen_p)
                env.step(chosen)

            assert env.all_finished(), 'not all episodes finished after replaying the solution'
        else:
            mode = 1 if greedy else 2  # 1: greedy decoding, 2: sampling
            min_env_time = int(self.steps_ratio * NT)
            # Pre-sampled uniform randoms, cycled over decode steps.
            R = torch.rand(NT * 2, batch_size, device=X.device)
            while True:
                env_time = env.time()
                # Termination is only tested every third step once past the
                # minimum step budget derived from decode_steps_ratio.
                if env_time > min_env_time and env_time % 3 == 0 and env.all_finished():
                    break

                clip = self.logit_clips[min(env_time, len(self.logit_clips) - 1)]
                sample_p = R[env_time % R.size(0)]
                chosen = X.new_empty(batch_size, dtype=torch.int64)
                varfeat = env.make_feat() if self.vars_dim > 0 else NULL
                state1, state2, chosen_p = self.decode(X, K, V, query, state1, state2,
                                                       varfeat, env.mask(), chosen, sample_p, clip, mode, memopt)
                query = X[p_index, chosen]
                p_list.append(chosen_p)
                env.step(chosen)

        env.finalize()
        return env, torch.stack(p_list, 1)

    def decode(self, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p, clip, mode, memopt):
        run_fn = self.decode_fn(clip, mode, memopt)
        if self.training and memopt > 3:
            # Trade compute for memory: recompute this decode step during
            # backward instead of storing its activations.
            return checkpoint(run_fn, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p)
        else:
            return run_fn(X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p)

    def decode_fn(self, clip, mode, memopt):
        memopt = 0 if memopt > 3 else memopt  # the checkpointed path runs with memopt=0

        def run_fn(X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p):
            return self.nn_decode(X, K, V, query, state1, state2,
                                  varfeat, mask, chosen, sample_p, clip, mode, memopt)

        return run_fn
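

# A minimal usage sketch (illustrative; `problem` and the shown batch size are
# assumptions, not part of this module):
#
#   nn_args = parse_nn_args(problem, {'encode_hidden_dim': 128,
#                                     'decode_logit_clips': '10,5'})
#   agent = Agent(nn_args)
#   env, p = agent(problem, batch_size=problem.batch_size * 8, greedy=False)
#   # p stacks Decode's per-step outputs for the chosen actions along dim 1;
#   # env holds the finalized rollout state.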


def parse_nn_args(problem, nn_args):
    """Infer per-feature input dimensions from the problem's features and
    record them, together with embedding sizes, into nn_args."""
    worker_dim = OrderedDict()
    task_dim = OrderedDict()
    edge_dim = OrderedDict()
    variable_dim = OrderedDict()
    embed_dict = OrderedDict()

    def set_dim_by_name(name, k, dim):
        if name.startswith("worker_task_"):
            edge_dim[k] = dim
        elif name.startswith("worker_"):
            worker_dim[k] = dim
        elif name.startswith("task_"):
            task_dim[k] = dim
        elif name.endswith("_matrix"):
            edge_dim[k] = dim
        else:
            raise Exception("attribute can't be feature: {}".format(k))

    feature_dict = make_feat_dict(problem)
    variables = [var(problem, problem.batch_size, 1) for var in problem.variables]
    variable_dict = dict([(var.name, var) for var in variables])
    for k, f in feature_dict.items():
        if isinstance(f, VariableFeature):
            var = variable_dict[f.name]
            assert hasattr(var, 'make_feat'), \
                "{} can't be a variable feature, name: {}".format(type(var).__name__, k)
            v = var.make_feat()
            if v.dim() == 2:
                variable_dim[k] = 1
            else:
                variable_dim[k] = v.size(-1)
        elif isinstance(f, SparseLocalFeature):
            edge_dim[k] = 1
            set_dim_by_name(f.value, k, 1)
        elif isinstance(f, LocalFeature):
            edge_dim[k] = 1
            set_dim_by_name(f.name, k, 1)
        elif isinstance(f, LocalCategory):
            edge_dim[k] = 1
        elif isinstance(f, GlobalCategory):
            set_dim_by_name(f.name, k, nn_args.setdefault('encode_hidden_dim', 128))
            embed_dict[k] = f.size
        elif isinstance(f, ContinuousFeature):
            v = problem.feats[k]
            if k.startswith("worker_task_") or k.endswith("_matrix"):
                simple_dim = 3
            else:
                simple_dim = 2

            if v.dim() == simple_dim:
                set_dim_by_name(f.name, k, 1)
            else:
                set_dim_by_name(f.name, k, v.size(-1))
        else:
            raise Exception("unsupported feature type: {}".format(type(f)))

    nn_args['worker_dim'] = worker_dim
    nn_args['task_dim'] = task_dim
    nn_args['edge_dim'] = edge_dim
    nn_args['variable_dim'] = variable_dim
    nn_args['embed_dict'] = embed_dict
    nn_args['feature_dict'] = feature_dict
    return nn_args


def make_feat_dict(problem):
    """Collect the problem's features into an OrderedDict keyed by name,
    raising on conflicting duplicates."""
    feature_dict = OrderedDict()

    def add(k, f):
        _f = feature_dict.get(k)
        if _f is None or _f == f:
            feature_dict[k] = f
        else:
            # The original built this message but discarded it, silently
            # keeping the first feature; raising matches the evident intent.
            raise Exception("duplicated feature, name: {}, feature1: {}, "
                            "feature2: {}".format(k, _f, f))

    for f in problem.features:
        if isinstance(f, VariableFeature):
            add(':'.join(['var', f.name]), f)
        elif isinstance(f, SparseLocalFeature):
            add(':'.join([f.index, f.value]), f)
        else:
            add(f.name, f)

    return feature_dict
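

# Self-contained sketch (illustrative, not part of the original API) of the
# rollout-to-instance mapping used in Agent.interact:
#
#   batch_size, sample_num = 6, 2
#   p_index = torch.div(torch.arange(batch_size), sample_num,
#                       rounding_mode='trunc')
#   # p_index == tensor([0, 0, 1, 1, 2, 2]); rollout i reads the encoder
#   # output of problem instance p_index[i], e.g. query = X[p_index, chosen].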