File size: 5,890 Bytes
0484cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import re
import curses
import numpy as np
import collections
import warnings
from typing import Optional
from gym.envs.toy_text.frozen_lake import FrozenLakeEnv

# Globally ignore every warning category for the rest of the process.
# NOTE(review): presumably this keeps warning text from being printed over the
# curses screen; it also hides deprecation warnings from gym/numpy -- confirm
# that a blanket filter is really wanted.
warnings.filterwarnings("ignore")


class FrozenLakeEnvCustom(FrozenLakeEnv):
    """FrozenLake environment with an additional ``"curses"`` render mode.

    In ``"curses"`` mode :meth:`render` draws three panels side by side on a
    curses screen: the game board, the greedy action per tile and the Q-values
    per tile. Any other render mode is delegated to the parent class.

    The constructor always initialises curses (``curses.initscr()``), whatever
    ``render_mode`` is; call :meth:`close` to hand the terminal back.
    """

    # Screen column where the title and the board start.
    _BOARD_COLUMN = 2
    # Number of blank columns between two neighbouring panels.
    _PANEL_GAP = 2
    # Glyph drawn for each action index; unknown actions are drawn as "?".
    _ACTION_CHARS = {0: "←", 1: "↓", 2: "→", 3: "↑"}
    # ANSI escape sequence searched for to locate the player in a text frame.
    _PLAYER_ESCAPE = "\x1b[41m"
    # Tile letters that make up a board row in a text frame.
    _TILE_PATTERN = re.compile(r"S|F|H|G")

    def __init__(
        self,
        render_mode: Optional[str] = None,
        desc=None,
        map_name="4x4",
        is_slippery=True,
    ):
        """Start curses, build the colour pairs and initialise the parent env.

        Args:
            render_mode: Render mode forwarded to the parent; ``"curses"``
                selects the curses UI implemented by :meth:`render`.
            desc: Forwarded unchanged to ``FrozenLakeEnv``.
            map_name: Forwarded unchanged to ``FrozenLakeEnv``.
            is_slippery: Forwarded unchanged to ``FrozenLakeEnv``.
        """
        self.curses_screen = curses.initscr()
        try:
            curses.start_color()
            curses.curs_set(0)
            self.curses_color_pairs = self.build_ncurses_color_pairs()

            # Blocking reads
            self.curses_screen.timeout(-1)

            super().__init__(
                render_mode=render_mode,
                desc=desc,
                map_name=map_name,
                is_slippery=is_slippery,
            )
        except BaseException:
            # Do not leave the terminal in curses mode when construction fails.
            curses.endwin()
            raise

    def close(self):
        """Leave curses mode, then run the parent's ``close``."""
        try:
            curses.endwin()
        except curses.error:
            # Best effort on shutdown: curses may already be torn down (or was
            # never started for this instance); there is nothing to restore.
            pass
        super().close()

    def build_ncurses_color_pairs(self):
        """Program the curses colours/pairs and return a character -> pair-ID map.

        Based on Deepmind Pycolab https://github.com/deepmind/pycolab/blob/master/pycolab/human_ui.py

        Returns:
            dict mapping every drawable character (tiles, player, arrows,
            ``"?"`` and ``" "``) to the curses colour-pair ID to draw it with.
        """

        # RGB components as passed to ``curses.init_color``.
        color_fg = {
            " ": (0, 0, 0),
            "S": (368, 333, 388),
            "H": (309, 572, 999),
            "P": (999, 364, 0),
            "F": (500, 999, 948),
            "G": (999, 917, 298),
            "?": (368, 333, 388),
            "←": (309, 572, 999),
            "↓": (999, 364, 0),
            "→": (500, 999, 948),
            "↑": (999, 917, 298),
        }

        cpair_0_fg_id, cpair_0_bg_id = curses.pair_content(0)
        ids = set(range(curses.COLORS - 1)) - {
            cpair_0_fg_id,
            cpair_0_bg_id,
        }

        # We use color IDs from large to small.
        ids = sorted(ids, reverse=True)

        # But only those color IDs we actually need. Characters sharing a
        # colour collapse onto a single entry (the last ID zipped wins).
        ids = ids[: len(color_fg)]
        color_ids = dict(zip(color_fg.values(), ids))

        # Program these colors into curses. Terminals that cannot redefine
        # colours keep their own palette for these IDs instead of failing.
        if curses.can_change_color():
            for color, cid in color_ids.items():
                curses.init_color(cid, *color)

        # Now add the default colors to the color-to-ID map.
        cpair_0_fg = curses.color_content(cpair_0_fg_id)
        cpair_0_bg = curses.color_content(cpair_0_bg_id)
        color_ids[cpair_0_fg] = cpair_0_fg_id
        color_ids[cpair_0_bg] = cpair_0_bg_id

        # The color pair IDs we use for all characters count up from 1; pair 0
        # is the curses default pair and is left untouched.
        color_pair = {
            character: pid for pid, character in enumerate(color_fg, start=1)
        }

        # Program these color pairs into curses, and that's all there is to do.
        for character, pid in color_pair.items():
            # No separate background colours are defined, so the foreground
            # colour doubles as the background.
            # NOTE(review): this also applies to the Q-value text, which is
            # therefore drawn in the same colour as its background -- confirm
            # that this is intended.
            cid = color_ids[color_fg[character]]
            curses.init_pair(pid, cid, cid)

        return color_pair

    def render_ncurses_ui(self, screen, board, color_pair, title, q_table):
        """Draw title, board, greedy-action grid and Q-table into ``screen``.

        Only the curses memory buffer is updated (``noutrefresh``); the caller
        has to call ``curses.doupdate()`` to show the result.

        Args:
            screen: curses window to draw on.
            board: 2-D array of single-character tiles, one row per board row.
            color_pair: character -> colour-pair ID map, as returned by
                :meth:`build_ncurses_color_pairs`.
            title: Text for the first screen line, or ``None`` to omit it.
            q_table: Q-values holding one row of action values per tile in
                row-major order, or ``None`` to draw the board only.
        """
        screen.erase()

        # Draw the title
        if title is not None:
            screen.addstr(0, self._BOARD_COLUMN, title)

        # Draw the game board
        for row_index, board_line in enumerate(board, start=1):
            screen.move(row_index, self._BOARD_COLUMN)
            for codepoint in "".join(board_line):
                screen.addch(codepoint, curses.color_pair(color_pair[codepoint]))

        if q_table is not None:
            self._draw_q_table(screen, board, color_pair, q_table)

        # Redraw the game screen (but in the curses memory buffer only).
        screen.noutrefresh()

    def _draw_q_table(self, screen, board, color_pair, q_table):
        """Draw the greedy-action grid and the Q-values next to the board."""
        n_rows = len(board)
        n_cols = len(board[0]) if n_rows else 0
        if n_rows == 0 or n_cols == 0:
            return

        # One vector of action values per tile, laid out like the board.
        q_values = np.asarray(q_table).reshape(n_rows, n_cols, -1)

        # Panels sit left to right, separated by a fixed gap; for a 4x4 board
        # this yields columns 8 and 14.
        action_column = self._BOARD_COLUMN + n_cols + self._PANEL_GAP
        values_column = action_column + n_cols + self._PANEL_GAP

        # Draw the action grid
        for row_index, row in enumerate(q_values.argmax(axis=-1), start=1):
            screen.move(row_index, action_column)
            for action in row:
                char = self._ACTION_CHARS.get(action, "?")
                screen.addch(char, curses.color_pair(color_pair[char]))

        # Draw the Q-table: every value is coloured like its action's arrow and
        # each tile's group of values is followed by a blank separator.
        blank_attr = curses.color_pair(color_pair[" "])
        for row_index, row in enumerate(q_values, start=1):
            screen.move(row_index, values_column)
            for tile_values in row:
                for action, value in enumerate(tile_values):
                    char = self._ACTION_CHARS.get(action, "?")
                    screen.addstr(
                        f" {value:.2f}", curses.color_pair(color_pair[char])
                    )
                screen.addstr(" ", blank_attr)

    def ansi_frame_to_board(self, frame_string):
        """Convert the parent's ANSI text frame into a 2-D array of tiles.

        The first line of the frame is skipped and empty lines are ignored.
        The tile at the position of the highlight escape sequence is replaced
        by ``"P"``.

        Args:
            frame_string: Text frame as produced by ``self._render_text()``.

        Returns:
            ``numpy`` array with one row of single-character strings per board
            row.
        """
        board = []
        for line in frame_string.split("\n")[1:]:
            if not line:
                continue
            row = self._TILE_PATTERN.findall(line)
            # The string offset of the escape sequence is used as the column
            # index.
            # NOTE(review): this assumes everything before the escape on the
            # line is one character per tile -- confirm against the parent's
            # text renderer.
            player_col = line.find(self._PLAYER_ESCAPE)
            if 0 <= player_col < len(row):
                row[player_col] = "P"
            board.append(row)

        return np.array(board)

    def render(self, title=None, q_table=None):
        """Render the environment.

        Args:
            title: Text shown above the board in ``"curses"`` mode.
            q_table: Q-values shown next to the board in ``"curses"`` mode.

        Returns:
            The board array that was drawn when ``render_mode`` is
            ``"curses"``; otherwise whatever the parent's ``render`` returns.
        """
        if self.render_mode != "curses":
            return super().render()

        board = self.ansi_frame_to_board(self._render_text())
        self.render_ncurses_ui(
            self.curses_screen, board, self.curses_color_pairs, title, q_table
        )

        # Show the screen to the user.
        curses.doupdate()
        return board

    def get_expected_new_state_for_action(self, action):
        """Return the state that ``action`` leads to when it does not slip.

        Reads the second field of the middle entry of
        ``self.P[self.s][action]``. With three listed outcomes that is the
        entry at index 1; with a single outcome it is that outcome.

        NOTE(review): assumes the parent stores each outcome as a tuple whose
        second field is the next state, and lists the intended direction in the
        middle of the slippery outcomes -- confirm against the installed gym
        version.
        """
        transitions = self.P[self.s][action]
        return transitions[len(transitions) // 2][1]