File size: 3,957 Bytes
0b8359d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

"""Tasks that test correctness of algorithms."""

from six.moves import xrange
from common import reward as reward_lib  # brain coder
from single_task import misc  # brain coder


class BasicTaskManager(object):
  """Wraps a generic reward function."""

  def __init__(self, reward_fn):
    self.reward_fn = reward_fn
    self.good_reward = 1.0

  def _score_string(self, string):
    actions = misc.bf_string_to_tokens(string)
    reward, correct = self.reward_fn(actions)
    return misc.RewardInfo(
        episode_rewards=[0.0] * (len(string) - 1) + [reward],
        input_case=None,
        correct_output=None,
        code_output=actions,
        input_type=None,
        output_type=misc.IOType.integer,
        reason='correct' if correct else 'wrong')

  def rl_batch(self, batch_size):
    reward_fns = [self._score_string] * batch_size
    return reward_fns


class Trie(object):
  """Trie for sequences."""
  EOS = ()

  def __init__(self):
    self.trie = {}

  def insert(self, sequence):
    d = self.trie
    for e in sequence:
      if e not in d:
        d[e] = {}
      d = d[e]
    d[self.EOS] = True   # Terminate sequence.

  def prefix_match(self, sequence):
    """Return prefix of `sequence` which exists in the trie."""
    d = self.trie
    index = 0
    for i, e in enumerate(sequence + [self.EOS]):
      index = i
      if e in d:
        d = d[e]
        if e == self.EOS:
          return sequence, True
      else:
        break
    return sequence[:index], False

  def next_choices(self, sequence):
    d = self.trie
    for e in sequence:
      if e in d:
        d = d[e]
      else:
        raise ValueError('Sequence not a prefix: %s' % (sequence,))
    return d.keys()


class HillClimbingTask(object):
  """Simple task that tests reward hill climbing ability.

  There are a set of paths (sequences of tokens) which are rewarded. The total
  reward for a path is proportional to its length, so the longest path is the
  target. Shorter paths can be dead ends.
  """

  def __init__(self):
    # Paths are sequences of sub-sequences. Here we form unique sub-sequences
    # out of 3 arbitrary ints. We use sub-sequences instead of single entities
    # to make the task harder by making the episodes last longer, i.e. more
    # for the agent to remember.
    a = (1, 2, 3)
    b = (4, 5, 6)
    c = (7, 8, 7)
    d = (6, 5, 4)
    e = (3, 2, 1)
    f = (8, 5, 1)
    g = (6, 4, 2)
    h = (1, 8, 3)
    self.paths = Trie()
    self.paths.insert([a, b, h])
    self.paths.insert([a, b, c, d, e, f, g, h])
    self.paths.insert([a, b, c, d, e, b, a])
    self.paths.insert([a, b, g, h])
    self.paths.insert([a, e, f, g])
    self.correct_sequence = misc.flatten([a, b, c, d, e, f, g, h])

    def distance_fn(a, b):
      len_diff = abs(len(a) - len(b))
      return sum(reward_lib.mod_abs_diff(ai - 1, bi - 1, 8)
                 for ai, bi in zip(a, b)) + len_diff * 4  # 8 / 2 = 4
    self.distance_fn = distance_fn

  def __call__(self, actions):
    # Compute reward for action sequence.
    actions = [a for a in actions if a > 0]
    sequence = [tuple(actions[i: i + 3]) for i in xrange(0, len(actions), 3)]
    prefix, complete = self.paths.prefix_match(sequence)
    if complete:
      return float(len(prefix)), actions == self.correct_sequence
    if len(prefix) == len(sequence):
      return float(len(prefix)), False
    next_pred = sequence[len(prefix)]
    choices = self.paths.next_choices(prefix)
    if choices == [()]:
      return (len(prefix) - len(next_pred) / 3.0), False
    min_dist = min(self.distance_fn(c, next_pred) for c in choices)
    # +1 reward for each element in the sequence correct, plus fraction torwards
    # closest next element.
    # Maximum distance possible is num_actions * base / 2 = 3 * 8 / 2 = 12
    return (len(prefix) + (1 - min_dist / 12.0)), False