baiyanlali-zhao committed
Commit eaf2e33 · 1 Parent(s): 7da037c
.gitignore ADDED
@@ -0,0 +1,166 @@
1
+ /misc/
2
+ /.idea/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
113
+ .pdm.toml
114
+ .pdm-python
115
+ .pdm-build/
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
166
+ /generation_results
README.md CHANGED
@@ -1,12 +1,32 @@
- ---
- title: NCERL Diverse PCG
- emoji: 📉
- colorFrom: green
- colorTo: green
- sdk: gradio
- sdk_version: 4.38.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Negatively Correlated Ensemble RL
+
+ ### Verified environment
+ * Python 3.9.6
+ * JPype 1.3.0
+ * dtw 1.4.0
+ * scipy 1.7.2
+ * torch 1.8.2+cu111
+ * numpy 1.20.3
+ * gym 0.21.0
+ * Pillow 10.0.0
+ * matplotlib 3.6.3
+ * pandas 1.3.2
+ * sklearn 1.0.1
+
+ ### How to use
+
+ All training runs are launched through `train.py` with an option and arguments. For example, `python train.py ncesac --lbd 0.3 --m 5` trains NCERL with the hyperparameters $\lambda = 0.3$ and $m = 5$.
+ The plotting script is `plots.py`.
+
+ * `python train.py gan`: to train a decoder that maps a continuous action to a game level segment.
+ * `python train.py sac`: to train a standard SAC as the policy for online game level generation.
+ * `python train.py asyncsac`: to train a SAC with an asynchronous evaluation environment as the policy for online game level generation.
+ * `python train.py ncesac`: to train an NCERL based on SAC as the policy for online game level generation.
+ * `python train.py egsac`: to train an episodic generative SAC (see the paper [*The fun facets of Mario: Multifaceted experience-driven PCG via reinforcement learning*](https://dl.acm.org/doi/abs/10.1145/3555858.3563282?casa_token=AHQWYSj_GyoAAAAA:MhwOltqfijP1NQj-c6NaTQikCnlNwyaMky07gCvTK5ZlSq063ew40awAcqEcw6S5zG9Sq9ZyDsspuaM)) as the policy for online game level generation.
+ * `python train.py pmoe`: to train a probabilistic mixture-of-experts policy (see the paper [*Probabilistic Mixture-of-Experts for Efficient Deep Reinforcement Learning*](https://arxiv.org/abs/2104.09122)) for online game level generation.
+ * `python train.py sunrise`: to train SUNRISE (see the paper [*SUNRISE: A Simple Unified Framework for Ensemble Learning in Deep Reinforcement Learning*](https://proceedings.mlr.press/v139/lee21g.html)) as the policy for online game level generation.
+ * `python train.py dvd`: to train DvD-SAC (see the paper [*Effective Diversity in Population Based Reinforcement Learning*](https://proceedings.neurips.cc/paper_files/paper/2020/hash/d1dc3a8270a6f9394f88847d7f0050cf-Abstract.html)) as the policy for online game level generation.
+
+ For the full list of training arguments, run `python train.py [option] --help`. A minimal sketch of scripting these commands into a hyperparameter sweep follows below.
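
A minimal sketch (not part of the repository) of scripting the commands above into a sweep over `--lbd` and `--m`; it assumes `train.py` accepts those flags exactly as documented, and the value grids mirror the ones used in `analysis/tests.py`.

```python
# Hedged sketch: sweep NCERL training over lambda and ensemble size m.
import itertools
import subprocess

lambdas = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]   # same grid as the test scripts
ensemble_sizes = [2, 3, 4, 5]

for lbd, m in itertools.product(lambdas, ensemble_sizes):
    # Each call trains one NCERL policy with the given lambda and m.
    subprocess.run(
        ["python", "train.py", "ncesac", "--lbd", str(lbd), "--m", str(m)],
        check=True,  # stop the sweep if a run fails
    )
```
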
analysis/generate.py ADDED
@@ -0,0 +1,71 @@
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ from src.gan.gankits import *
5
+ from src.utils.filesys import getpath
6
+ from src.utils.img import make_img_sheet
7
+ from src.utils.datastruct import RingQueue
8
+ from src.olgen.olg_policy import RLGenPolicy, RandGenPolicy
9
+ from src.smb.level import lvlhcat, save_batch
10
+
11
+
12
+ def rand_gen_levels(n=100, h=50, dest_path=''):
13
+ levels = []
14
+ latvecs = []
15
+ decoder = get_decoder('models/decoder.pth', 'cuda:0')
16
+ init_arxv = np.load(getpath('smb/init_latvecs.npy'))
17
+ for _ in range(n):
18
+ z0 = init_arxv[random.randrange(0, len(init_arxv))]
19
+ z0 = torch.tensor(z0, device='cuda:0', dtype=torch.float)
20
+ z = torch.cat([z0, sample_latvec(h, 'cuda:0')], dim=0)
21
+ lvl = lvlhcat(process_onehot(decoder(z)))
22
+ levels.append(lvl)
23
+ latvecs.append(z.cpu().numpy())
24
+ if dest_path:
25
+ save_batch(levels, dest_path)
26
+ np.save(getpath(dest_path), np.stack(latvecs))
27
+ return levels, np.stack(latvecs)
28
+
29
+ def generate_levels(policy, dest_folder='', batch_name='samples.lvls', n=200, h=50, parallel=64, save_img=False):
30
+ levels = []
31
+ latvecs = []
32
+ obs_queues = [RingQueue(policy.n) for _ in range(parallel)]
33
+ init_arxv = np.load(getpath('smb/init_latvecs.npy'))
34
+ decoder = get_decoder('models/decoder.pth', 'cuda:0')
35
+ while len(levels) < n:
36
+ veclists = [[] for _ in range(parallel)]
37
+ for queue, veclist in zip(obs_queues, veclists):
38
+ queue.clear()
39
+ init_latvec = init_arxv[random.randrange(0, len(init_arxv))]
40
+ queue.push(init_latvec)
41
+ veclist.append(init_latvec)
42
+ for _ in range(h):
43
+ obs = np.stack([np.concatenate(queue.to_list()) for queue in obs_queues])
44
+ actions = policy.step(obs)
45
+ for queue, veclist, action in zip(obs_queues, veclists, actions):
46
+ queue.push(action)
47
+ veclist.append(action)
48
+ for veclist in veclists:
49
+ latvecs.append(np.stack(veclist))
50
+ z = torch.tensor(latvecs[-1], device='cuda:0').view(-1, nz, 1, 1)
51
+ lvl = lvlhcat(process_onehot(decoder(z)))
52
+ levels.append(lvl)
53
+ # print(f'{len(levels)}/{n} generated')
54
+ if dest_folder:
55
+ os.makedirs(getpath(dest_folder), exist_ok=True)
56
+ save_batch(levels[:n], getpath(dest_folder, batch_name))
57
+ if save_img:
58
+ for i, lvl in enumerate(levels[:n]):
59
+ lvl.to_img(f'{dest_folder}/lvl-{i}.png')
60
+ return levels[:n]
61
+
62
+
63
+ def make_samples(path, n=12, h=20, space=12):
64
+ plc = RLGenPolicy.from_path(path)
65
+ levels = generate_levels(plc, n=n, h=h)
66
+ imgs = [lvl.to_img() for lvl in levels]
67
+ make_img_sheet(imgs, ncols=1, y_margin=space, save_path=f'{path}/samples.png')
68
+ pass
69
+
70
+ if __name__ == '__main__':
71
+ pass
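
A note on `generate_levels` above: each observation is the concatenation of the most recent latent vectors held in a fixed-capacity `RingQueue`, and every action chosen by the policy is pushed back into that queue. The stand-alone sketch below mimics that rolling-window loop with `collections.deque`; the latent size and the random "policy" are placeholders, not the project's `RingQueue` or `RLGenPolicy`.

```python
# Stand-alone illustration of the rolling observation window in generate_levels().
from collections import deque
import numpy as np

nz = 20                         # latent vector size (placeholder value)
history = 3                     # how many recent vectors form one observation
horizon = 5                     # number of segments to generate

queue = deque(maxlen=history)   # the oldest vector is dropped automatically
queue.append(np.zeros(nz))      # latent vector of the initial segment

for _ in range(horizon):
    obs = np.concatenate(list(queue))   # observation = recent vectors, concatenated
    action = np.random.randn(nz)        # placeholder for policy.step(obs)
    queue.append(action)                # the new segment joins the history
```

In the repository itself, padding of short histories is handled separately (see `get_padded_obs` as used in `analysis/tests.py`).
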
analysis/initial_seg.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96086cfef10b8b7993278c96fe34916e08f3566655a5f419d41593db73d93468
3
+ size 40128
analysis/tests.py ADDED
@@ -0,0 +1,213 @@
1
+ import os
2
+ import csv
3
+ import time
4
+ import random
5
+ from src.smb.level import *
6
+ from src.drl.me_reg import *
7
+ from src.drl.nets import esmb_sample
8
+ from src.utils.filesys import getpath
9
+ from src.utils.datastruct import RingQueue
10
+ from src.smb.asyncsimlt import AsycSimltPool
11
+ from src.env.environments import get_padded_obs
12
+ from src.olgen.ol_generator import VecOnlineGenerator, OnlineGenerator
13
+ from src.drl.drl_uses import load_cfgs, load_performance
14
+ from src.olgen.olg_policy import process_obs, RandGenPolicy, RLGenPolicy, EnsembleGenPolicy
15
+
16
+
17
+ def evaluate_rewards(lvls, rfunc='default', dest_path='', parallel=1, eval_pool=None):
18
+ internal_pool = eval_pool is None
19
+ if internal_pool:
20
+ eval_pool = AsycSimltPool(parallel, rfunc_name=rfunc, verbose=False, test=True)
21
+ res = []
22
+ for lvl in lvls:
23
+ eval_pool.put('evaluate', (0, str(lvl)))
24
+ buffer = eval_pool.get()
25
+ for _, item in buffer:
26
+ res.append([sum(r) for r in zip(*item.values())])
27
+ if internal_pool:
28
+ buffer = eval_pool.close()
29
+ else:
30
+ buffer = eval_pool.get(True)
31
+ for _, item in buffer:
32
+ res.append([sum(r) for r in zip(*item.values())])
33
+ if len(dest_path):
34
+ np.save(dest_path, res)
35
+ return res
36
+
37
+ def evaluate_mnd(lvls, refs, parallel=2):
38
+ eval_pool = AsycSimltPool(parallel, verbose=False, refs=[str(ref) for ref in refs])
39
+ # m, _ = len(lvls), len(refs)
40
+ res = []
41
+ for lvl in lvls:
42
+ eval_pool.put('mnd_item', str(lvl))
43
+ res += eval_pool.get()
44
+ res += eval_pool.get(wait=True)
45
+ res = np.array(res)
46
+ eval_pool.close()
47
+ return np.mean(res[:, 0]), np.mean(res[:, 1])
48
+
49
+ def evaluate_mpd(lvls, parallel=2):
50
+ task_datas = [[] for _ in range(parallel)]
51
+ for i, (A, B) in enumerate(combinations(lvls, 2)):
52
+ # lvlA, lvlB = lvls[i * 2], lvls[i * 2 + 1]
53
+ task_datas[i % parallel].append((str(A), str(B)))
54
+
55
+ hms, dtws = [], []
56
+ eval_pool = AsycSimltPool(parallel, verbose=False)
57
+ for task_data in task_datas:
58
+ eval_pool.put('mpd', task_data)
59
+ res = eval_pool.get(wait=True)
60
+ for task_hms, _ in res:
61
+ hms += task_hms
62
+ # dtws += task_dtws
63
+ return np.mean(hms) #, np.mean(dtws)
64
+
65
+ def evaluate_gen_log(path, parallel=5):
66
+ rfunc_name = load_cfgs(path, 'rfunc')
67
+ f = open(getpath(f'{path}/step_tests.csv'), 'w', newline='')
68
+ wrtr = csv.writer(f)
69
+ cols = ['step', 'r-avg', 'r-std', 'diversity', '']  # header matches the row written below; MND/DTW columns are not computed in this function
70
+ wrtr.writerow(cols)
71
+ start_time = time.time()
72
+ for lvls, name in traverse_batched_level_files(f'{path}/gen_log'):
73
+ step = name[4:]
74
+ rewards = [sum(item) for item in evaluate_rewards(lvls, rfunc_name, parallel=parallel)]
75
+ r_avg, r_std = np.mean(rewards), np.std(rewards)
76
+ # mpd_hm, mpd_dtw = evaluate_mpd(lvls, parallel=parallel)
77
+ mpd = evaluate_mpd(lvls, parallel=parallel)
78
+ line = [step, r_avg, r_std, mpd, '']
79
+ wrtr.writerow(line)
80
+ f.flush()
81
+ print(
82
+ f'{path}: step{step} evaluated in {time.time()-start_time:.1f}s -- '
83
+ + '; '.join(f'{k}: {v}' for k, v in zip(cols, line))
84
+ )
85
+ f.close()
86
+ pass
87
+
88
+ def evaluate_generator(generator, nr=200, h=50, parallel=5, dest_path=None, additional_info=None, rfunc_name='default'):
89
+ if additional_info is None: additional_info = {}
90
+ ''' Test Reward '''
91
+ lvls = generator.generate(nr, h)
92
+ rewards = [sum(item) for item in evaluate_rewards(lvls, parallel=parallel, rfunc=rfunc_name)]
93
+ r_avg, r_std = np.mean(rewards), np.std(rewards)
94
+ ''' Test MPD '''
95
+ # mpd, _ = evaluate_mpd(lvls, parallel=parallel)
96
+ mpd, *_ = evaluate_mpd(generator.generate(3000*2, h), parallel=parallel)
97
+ res = {
98
+ 'r-avg': r_avg, 'r-std': r_std, 'div': mpd,
99
+ }
100
+ res.update(additional_info)
101
+ if dest_path:
102
+ with open(getpath(dest_path), 'w', newline='') as f:
103
+ keys = [k for k in res.keys()]
104
+ wrtr = csv.writer(f)
105
+ wrtr.writerow(keys + [''])
106
+ wrtr.writerow([res[k] for k in keys] + [''])
107
+ return res
108
+ pass
109
+
110
+ def evaluate_jmer(training_path, n=1000, max_parallel=None, device='cuda:0'):
111
+ init_vecs = np.load(getpath('smb/init_latvecs.npy'))
112
+ try:
113
+ m, histlen, h, gamma, me_type = load_cfgs(training_path, 'm', 'N', 'h', 'gamma', 'me_type')
114
+ except KeyError:
115
+ return 0.
116
+ mereg_func = LogWassersteinExclusion(1.) if me_type == 'logw' else WassersteinExclusion(1.)
117
+ model = torch.load(getpath(training_path, 'policy.pth'), map_location=device)
118
+ model.requires_grad_(False)
119
+ if max_parallel is None:
120
+ max_parallel = min(n, 512)
121
+ me_regs = []
122
+ obs_queues = [RingQueue(histlen) for _ in range(max_parallel)]
123
+ while len(me_regs) < n:
124
+ size = min(max_parallel, n - len(me_regs))
125
+ mereg_vals, discount = np.zeros([size]), 1.
126
+ veclists = [[] for _ in range(size)]
127
+ for queue, veclist in zip(obs_queues, veclists):
128
+ queue.clear()
129
+ init_latvec = init_vecs[random.randrange(0, len(init_vecs))]
130
+ queue.push(init_latvec)
131
+ veclist.append(init_latvec)
132
+ for _ in range(h):
133
+ obs = np.stack([get_padded_obs(queue.to_list(), histlen) for queue in obs_queues[:size]])
134
+ muss, stdss, betas = model.get_intermediate(process_obs(obs, device))
135
+ mereg_vals += discount * mereg_func.forward(muss, stdss, betas).squeeze().cpu().numpy()
136
+ discount *= gamma
137
+ actions, _ = esmb_sample(muss, stdss, betas)
138
+ for queue, veclist, action in zip(obs_queues, veclists, actions.cpu().numpy()):
139
+ queue.push(action)
140
+ veclist.append(action)
141
+ me_regs += mereg_vals.tolist()
142
+ return me_regs
143
+
144
+ def evaluate_baseline(*rfuncs, parallel=4):
145
+ nr, md, nd, h = 100, 1000, 200, 50
146
+ gen_policy = RandGenPolicy()
147
+ olgenerator = OnlineGenerator(gen_policy)
148
+ lvls, refs = olgenerator.generate(md, h), olgenerator.generate(nd, h)
149
+ divs_h, divs_js = evaluate_mnd(lvls, refs, parallel=parallel)
150
+ keys, vals = ['d-h', 'd-js'], [divs_h, divs_js]
151
+ print(f'Diversity of baseline generator: Hamming {divs_h:.2f}; TPJS {divs_js:.2f}')
152
+ for rfunc in rfuncs:
153
+ try:
154
+ print(f'Start to evaluate {rfunc}')
155
+ start_time = time.time()
156
+ lvls = olgenerator.generate(nr, h)
157
+ rewards = [sum(item) for item in evaluate_rewards(lvls, parallel=parallel, rfunc=rfunc)]
158
+ keys.append(rfunc)
159
+ vals.append(np.mean(rewards))
160
+ print(f'Evaluation for {rfunc} finished in {time.time()-start_time:.2f}s')
161
+ print(f'Evaluation results for {rfunc}: {vals[-1]:.2f}')
162
+ except AttributeError:
163
+ continue
164
+ with open(getpath('training_data', 'baselines.csv'), 'w', newline='') as f:
165
+ wrtr = csv.writer(f)
166
+ wrtr.writerow(keys)
167
+ wrtr.writerow(vals)
168
+
169
+ def sample_initial():
170
+ playable_latvecs = np.load(getpath('smb/init_latvecs.npy'))
171
+ indexes = random.sample([*range(len(playable_latvecs))], 500)
172
+ z = playable_latvecs[indexes, :]
173
+
174
+ np.save(getpath('analysis/initial_seg.npy'), z)
175
+ pass
176
+
177
+ def generate_levels_for_test(h=25):
178
+ init_set = np.load(getpath('analysis/initial_seg.npy'))
179
+ def _generte_one(policy, path):
180
+ try:
181
+ start = time.time()
182
+ generator = VecOnlineGenerator(policy, vec_num=len(init_set))
183
+ fd, _ = os.path.split(getpath(path))
184
+ os.makedirs(fd, exist_ok=True)
185
+ generator.re_init(init_set)
186
+ lvls = generator.generate(len(init_set), h, rand_init=False)
187
+ save_batch(lvls, path)
188
+ print('Save to', path, '%.2fs' % (time.time() - start))
189
+ except FileNotFoundError as e:
190
+ print(e)
191
+ for l, m in product(['0.0', '0.1', '0.2', '0.3', '0.4', '0.5'], [2, 3, 4, 5]):
192
+ for i in range(1, 6):
193
+ pi_path = f'training_data/varpm-fhp/l{l}_m{m}/t{i}'
194
+ _generte_one(RLGenPolicy.from_path(pi_path), f'test_data/varpm-fhp/l{l}_m{m}/t{i}/samples.lvls')
195
+ pi_path = f'training_data/varpm-lgp/l{l}_m{m}/t{i}'
196
+ _generte_one(RLGenPolicy.from_path(pi_path), f'test_data/varpm-lgp/l{l}_m{m}/t{i}/samples.lvls')
197
+ for algo in ['sac', 'egsac', 'asyncsac', 'pmoe']:
198
+ for i in range(1, 6):
199
+ pi_path = f'training_data/{algo}/fhp/t{i}'
200
+ _generte_one(RLGenPolicy.from_path(pi_path), f'test_data/{algo}/fhp/t{i}/samples.lvls')
201
+ pi_path = f'training_data/{algo}/lgp/t{i}'
202
+ _generte_one(RLGenPolicy.from_path(pi_path), f'test_data/{algo}/lgp/t{i}/samples.lvls')
203
+ for algo in ['sunrise', 'dvd']:
204
+ for i in range(1, 5):
205
+ pi_path = f'training_data/{algo}/fhp/t{i}'
206
+ _generte_one(EnsembleGenPolicy.from_path(pi_path), f'test_data/{algo}/fhp/t{i}/samples.lvls')
207
+ pi_path = f'training_data/{algo}/lgp/t{i}'
208
+ _generte_one(EnsembleGenPolicy.from_path(pi_path), f'test_data/{algo}/lgp/t{i}/samples.lvls')
209
+ pass
210
+
211
+
212
+ if __name__ == '__main__':
213
+ generate_levels_for_test()
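
`evaluate_mpd` above reports diversity as the mean distance over all pairs of generated levels, farming the pairs out to an `AsycSimltPool`. The serial sketch below computes the same kind of statistic with a plain Hamming distance on string-encoded levels; the helpers are hypothetical stand-ins, not the project's `hamming_dis` or simulation pool.

```python
# Serial sketch: mean pairwise Hamming distance as a diversity score.
from itertools import combinations

def hamming(a: str, b: str) -> int:
    # Number of positions at which two equally long level strings differ.
    return sum(ca != cb for ca, cb in zip(a, b))

def mean_pairwise_distance(levels: list[str]) -> float:
    pairs = list(combinations(levels, 2))
    return sum(hamming(a, b) for a, b in pairs) / len(pairs)

print(mean_pairwise_distance(["XX--XX", "XX-SXX", "X---XX"]))  # prints 1.3333333333333333
```
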
app.py ADDED
@@ -0,0 +1,46 @@
1
+ import random
2
+
3
+ import gradio as gr
4
+ import os
5
+
6
+ from src.olgen.ol_generator import VecOnlineGenerator
7
+ from src.olgen.olg_game import MarioOnlineGenGame
8
+ from src.olgen.olg_policy import RLGenPolicy
9
+ from src.smb.level import save_batch
10
+ from src.utils.filesys import getpath
11
+ from src.utils.img import make_img_sheet
12
+
13
+ import torch
14
+
15
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
16
+
17
+ def generate_and_play():
18
+ path = 'models/example_policy'
19
+ # Generate with example policy model
20
+ N, L = 8, 10
21
+ plc = RLGenPolicy.from_path(path, device)
22
+ generator = VecOnlineGenerator(plc, g_device=device)
23
+ fd, _ = os.path.split(getpath(path))
24
+ os.makedirs(fd, exist_ok=True)
25
+
26
+ lvls = generator.generate(N, L)
27
+ # save_batch(lvls, f'{path}/samples.lvls')
28
+ imgs = [lvl.to_img() for lvl in lvls]
29
+ return imgs
30
+ # make_img_sheet(imgs, 1, save_path=f'{path}/samples.png')
31
+
32
+ # # Play with the example policy model
33
+ # game = MarioOnlineGenGame(path)
34
+ # game.play()
35
+
36
+
37
+ with gr.Blocks(title="NCERL Demo") as demo:
38
+ gallery = gr.Gallery(
39
+ label="Generated images", show_label=False, elem_id="gallery"
40
+ , columns=[3], rows=[1], object_fit="contain", height="auto")
41
+ btn = gr.Button("Generate levels", scale=0)
42
+
43
+ btn.click(generate_and_play, None, gallery)
44
+
45
+ if __name__ == "__main__":
46
+ demo.launch()
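
A stripped-down version of the wiring in `app.py`: a button callback returning a list of PIL images feeds a `gr.Gallery`. The image-producing function here is a dummy placeholder rather than the project's generator, and the layout arguments are kept minimal.

```python
# Minimal sketch of the button -> gallery wiring used in app.py (dummy images).
import gradio as gr
from PIL import Image

def make_images():
    # Placeholder for generate_and_play(): return a list of PIL images.
    return [Image.new("RGB", (160, 128), c) for c in ("red", "green", "blue")]

with gr.Blocks(title="Gallery wiring sketch") as demo:
    gallery = gr.Gallery(show_label=False, columns=3)
    btn = gr.Button("Generate")
    btn.click(make_images, None, gallery)

if __name__ == "__main__":
    demo.launch()
```
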
generate_and_play.py ADDED
@@ -0,0 +1,27 @@
1
+ import os
2
+
3
+ from src.olgen.ol_generator import VecOnlineGenerator
4
+ from src.olgen.olg_game import MarioOnlineGenGame
5
+ from src.olgen.olg_policy import RLGenPolicy
6
+ from src.smb.level import save_batch
7
+ from src.utils.filesys import getpath
8
+ from src.utils.img import make_img_sheet
9
+
10
+ if __name__ == '__main__':
11
+ path = 'models/example_policy'
12
+ # Generate with example policy model
13
+ N, L = 8, 10
14
+ plc = RLGenPolicy.from_path(path)
15
+ generator = VecOnlineGenerator(plc)
16
+ fd, _ = os.path.split(getpath(path))
17
+ os.makedirs(fd, exist_ok=True)
18
+
19
+ lvls = generator.generate(N, L)
20
+ save_batch(lvls, f'{path}/samples.lvls')
21
+ imgs = [lvl.to_img() for lvl in lvls]
22
+ make_img_sheet(imgs, 1, save_path=f'{path}/samples.png')
23
+
24
+ # # Play with the example policy model
25
+ # game = MarioOnlineGenGame(path)
26
+ # game.play()
27
+ pass
models/decoder.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:795903ed4957a4fc8b5a349113477643f945efe272d33a276a55671084f10051
3
+ size 1754728
models/example_policy/cfgs.json ADDED
@@ -0,0 +1 @@
1
+ {"N": 5, "gamma": 0.9, "h": 50, "rfunc": "lgp"}
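
The configuration above is plain JSON holding the observation history length `N`, discount `gamma`, generation horizon `h`, and reward function name `rfunc`; `load_cfgs` (imported from `src.drl.drl_uses` in `analysis/tests.py`) presumably reads these keys next to `policy.pth`. A trivial hedged sketch of loading it directly:

```python
import json

# Path as laid out in this commit; adjust if the model directory moves.
with open("models/example_policy/cfgs.json") as f:
    cfgs = json.load(f)

history_len = cfgs["N"]    # observation history length
gamma = cfgs["gamma"]      # discount factor
horizon = cfgs["h"]        # number of segments generated per episode
reward_fn = cfgs["rfunc"]  # reward function name, e.g. "lgp"
```
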
models/example_policy/policy.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95bd64a4667f1a55f73897bf1b8e9fff63d0cd2adb860ad799d180c53bc036b8
3
+ size 2430875
models/example_policy/samples.lvls ADDED
@@ -0,0 +1,135 @@
1
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
2
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
3
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
4
+ ---------------------------------#----------------------------------------------------------------------------------------------------------------------------------------------
5
+ ------------------------------oo-------------------------------o---------------------------------------------------------------o------------------------------------------------
6
+ -------------------------------------------------------------------------------------------------o------------------------------------------------------------------------------
7
+ --------------------------------------SSSSSSSSSS---------------------Q--------------------------------------------------------------QQQ-----------------------------QQQQ--------
8
+ --------------So----------------------------------------------------------------------------------------------------------------------------------------------------------------
9
+ ----------------------------------------------o----------------------o---------------------------------------K----------------------------------------------------------------o-
10
+ -----------#---------------------o---------------------------------------------------------------------------2------------------------------------------------------------------
11
+ ---------####--------------------------------oS------------------#---SoS-----US------------------------------U-------------------#--SSSS-----US--------tt--------##-###S-----US-
12
+ ---------####----------tt-----T------------------------TT----#--TT-----------------------TT------------B---------------TT----#T-TT---------------------tt--------##-------------
13
+ --------########-------Tt-----T------------------------TT----TT-TT-----K----------------TTT------------B---------------TT----TT-TT---------------------tt----TT--##-------------
14
+ -------#########--gggg-Tt---kkT------k-----kk-----gggg-TT---kTT-T#-k-k-g--k-----k-ty----TTT--ggg---k-gog----kkk---gggg-TT---kTT-T--k-k-g--k-k-----g----tt---kkg--##k-k-g--k-k---
15
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
16
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
17
+ ;
18
+ ------------------------------------------------------------------------------------------------------------------------------------------S----------------------------S--S-----
19
+ ----------------------------------------------------------------------S-------------------------------------------------------------------------------------------S-SSSSSS------
20
+ -----------------------------------------------------------------------------------------------------------------------------------------%%-------------------------------------
21
+ ----------------------------------S------------------------------------------------------------------------------------------------------||-------------------------------------
22
+ ----------------------------------------------------------------------------Koo---------------------------------------o------------------||----------------------------o-o-----o
23
+ -----------------------------------------------------------------------------------------------------------------------------------------||-------------------------------------
24
+ ----------------S--Q--SSoSS--SSS--o-----------------QQoo--------------SSS----SSS%---SS-----------------------U-------SSSS-------------SSSSSSSS-------------------------SS-------
25
+ S-------------------------------------------------------------------------------|--------------------------------o--------------------------------------------------------------
26
+ SSS-So---------------------------------------S-S--------------------------------|-------------o--------------------------------------------------------------K------------------
27
+ ----------------------------------------------S--------o------------------------|--------------------------------S-------------------------------------------2------------oo----
28
+ ----------------Q---QS@Q----S@SSS-------------S--------2-----U-----------------S|------------US--------------U-------------------------------SS--------------U-------------%----
29
+ ----------T--------------------------------------------tt---------------#-------|------B---------------TT----#-----------------------------------------B-------------------|----
30
+ ---------TT--------------------------------K-----------tt--------------##--#-#--|------B---------------TT---TTT---------#-------------t----------------B-------------------|----
31
+ ---------TT----#---k------k----------------b---g--gggg-tt---k--------####----#--|-gggggb--k-k-k---g----TT---kT#-------###-------------t------k-g---k-gog----kkk------------|----
32
+ ---XXX-XXXXXXXXXXXXXXXXXXXXXXXXX------XXXXXXoXXXXXXXXXXXXXXXXXXXXoXXXXXXXX--oXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX%XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--%%%%-----|---@
33
+ ---XXX-X-XXXXXXXXXXXXXXXXXXXXXXX------XXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX@--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX|-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---||------|----
34
+ ;
35
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
36
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
37
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
38
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
39
+ ------------------------------oo-----------------------------------------------o------------------------------------------------------------------------------------------------
40
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
41
+ ----------------%S---So--------------QQ--------------------------------------------------------------Q-QQ--------S--QQSSQSSSSSSS-----QQQ----------------------------------------
42
+ ----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------
43
+ ----------------|--------------------------------------------K--------------------------------##-------------oo------------------------------oo---------------------------------
44
+ ----------------|--------------------------------------------U-------------------o-----o-----###------g-------------------------------------------------------------------------
45
+ ----------------|----------------#---SoS-----US--------------U-------------------------tt----###--#-####Q---S@S-#------------US--------------U-------------------------------o--
46
+ ----------------|------TT-----K-TT---------------------B------K--------TT----#T--------tt---###------------------------------##--------K---------------TT----TT--------t--------
47
+ ----------------|------TT-----U-TT-----K---------------B---------------TT----TT--------tt--####-----------------------------###------------------------TT----TT--------t--------
48
+ --kk-----------g|-gggg-Tt---k-U-T#-k-k----k--------k---t----k-----gggg-TT---kTT-Tg-----tt--####----k------k-k-----ggggg----####----k-kkyk---kkk--ggg-g-TT---TTT------k-tt----#--
49
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--XXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
50
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
51
+ ;
52
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
53
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
54
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
55
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
56
+ ------------------------------oo-----------------------------------------------------------------------------------------------o------------------------------------------------
57
+ ---------------------------------o-----------------------------------------------------------------------------------------------------------------o----------------------------
58
+ ----------------%----So--------------Q---------------QQQQ------------Qo--------------QQQQS-----#--------------------------------SooS--SS-----S-----SQQQ---------------o---------
59
+ ----------------|------------------------------------------------------------------------------#--------------------------------------------------------------------------------
60
+ ----------------|----------------------------------------------------------------------------###-------------K------------------------------------------------o-----------------
61
+ ----------------|----------------------------------------------------------------------------###-------------2---------------------------------o--------------------------------
62
+ -------------oo-|----------------#--USoS-----US--#------------------------------Q-Q----QQ----###----------------------------------UQS------------------------US--------------o--
63
+ ----------------|------TT-----K-TT---------------##----t---------------t---------------------###-------B---------------TT----#T----------------------------------------tt-------
64
+ ---------------@|------TT-----U-TT-----K---------#---------------------t--------------------####-------B---------------TT----TT----------------------------------------tt-------
65
+ ---gg----------g|-gggg-Tt---k-U-T--k-k----k------#-k--kk-----k-----k-gog----kkk---or--------####---k-gog----k-k---gggg-TT---kTT--------------------kgggg--k-k-----ggggott---kkk-
66
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX%%%%%-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
67
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-|XX--XXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
68
+ ;
69
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
70
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
71
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
72
+ ----------------------------------------------------------------------------------------------------SSSS------------------------------------------------------------------------
73
+ -----------------------------------------------o--------------------------------------------------------------------------------------------------------------------------------
74
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
75
+ --------------------QS--------------------------------Q-Q------------------------------------U--SSSSSSSSo----SSS---S@S@QQ-------%---SS------------------------------QQQQ--------
76
+ --------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------
77
+ -----------------------------------------------------------------------------K---------------------------------------o----------|----o--------o-------------------------------o-
78
+ -----------#-----------------------------------------------------------------2------------------g-------------------------------|------------------------------------g----------
79
+ ---------TT#-------------------------------------T#--S#S-----US--------------U---------------U--S--SS---------S-----USSS-----US-|----S-2-----US------------------######S-----US-
80
+ ---------TT----------------------------TT----TT-TT---------------------B---------------TT----#---------K------------------------|------K--K------------tt----#---##-------------
81
+ --------TTT----T-----------------------TT----TT------------------------B---------------TT----TT---------------------------------|------B---------------Tt----TT--##-------------
82
+ T-------TTT----T-----gg-------kg--gggg-TT---kTT-----------k--------k-gog----k-k---g----TT---kT#-------k-----k------kgggg----k---|--k-gog--k-k-----ggg--Tt----kT--##k---g--k-k---
83
+ XXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
84
+ X-XXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
85
+ ;
86
+ --------------------------------------------------------------------------------------------------------------------------S-----------------------------------------------------
87
+ -----------------------------------------------------------------------------------------------------------------S-SSSSS-SS-----------------------------------------------------
88
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
89
+ ------------------------------------------------------------------------------------------------SSSSSS--------------------------------------------------------------------------
90
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
91
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------oo----------------
92
+ ------SSS---SS--------------------------------------SQo------------------------------Q--------------------------@SSSSSSSQSSSSSSS----SQo-----------------------SS----QQQQQSS--S@S
93
+ ------------------------------------------------------------------------------------------------------------------------------------------------S--------------#----------------
94
+ -------------------------------------------------------------K-----------------------o-------------------------------------------------------U--S--------------#----------------
95
+ -------------------------------------------------------------------------------------------------------o---------------g---------------------------------------#----------------
96
+ ---------S---@S------------------#--USSS-----US-----US-2-----U---------tt-----------USSS-----US--------tt--------SQ-SSSQQ----US--------2-----U----------------##Q-Q--QQQQS---o--
97
+ -----------------------tt----TT-TT---------------------K---------------tt-------T----------------------tt----T-------------------------K--K------------------###----------------
98
+ -----------------------Tt----TT-TT---------------------B---------------tt------------------------------tt----T-------------------------B-------------------#####----------------
99
+ ----ggk-----k-----gggg-Tt---kkTTT--k-k----k--------k-gog--k-k-----g----tt---kkg----k---g--k-----------ttt---kkk----k---------------k-gog--k-k--------------#####----------------
100
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
101
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
102
+ ;
103
+ -------------------------------------------------------k------------------------------------------------------------------------------------------------------------------------
104
+ -------------------------------------S------------------------------------------------------------------------------------------------------------------------------------------
105
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
106
+ -------SSSSSSSSS----------------------------------------------------------------------------------------------------------------------------------------------------------------
107
+ ------------------------------oo---------------------------------------------------------------o-----------o-o----o----------o--------------------------------------------------
108
+ ---------------------------------------------o-------------------------------------------------------------------------------------------------o--------------------------------
109
+ ----------------%S---So----------------Q--------------SSSSSSSSSS-------------------------------------SQSoSSS-SS----S--SSSS----o-----QQ---------S--------------------------------
110
+ ----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------
111
+ ----------------|------------------------------------------------------------U-------------------------------------------------------------------------------K------------------
112
+ ----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------
113
+ ----SS----------|----------------T--S##SS----US------------SSS---------------U------------------------QQQ----USS--Qo-----#------QSQ---SSSSSSS%S--------------U---------------o--
114
+ ----------------|------TT-----K-TT-----------------------------------------------------TT----TT-------------------------##-------------------|---------B---------------Tt----#--
115
+ ----------------|------TT-----U-TT-------------------------------------B---------------TT----TT------------------------###---#---------------|---------B---------------TT----TT-
116
+ ----------------|-gggg-Tt---k-U-TT---k----b------------------------k---b----kkk---gggg-TT---kTT----U------------------####--------or-k-------|-----k-gog----kkk---ggg--Tt----kT-
117
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXX--XX---------X--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXX--XXXXXXXXXX--XXXXXXXX---XXXXX%XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
118
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXX---X---------X--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXX--X-XXX-XXXX--XXXXXXXX---XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
119
+ ;
120
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
121
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
122
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
123
+ ----------S-------------------------------------------------------------------------------------------------------------------------------------SSS-----------------------------
124
+ -----------------------------U-o-----------------------------------------------o------------------------------------------------------------------------------------------------
125
+ -------------------------------------------------o---------------------------------------------o--------------------------------------------------------------------------------
126
+ ------------------------------------------SS---S%--------------------------------------------------------------------QQ-------------QQoo-----U-------QQQo------------Qo---------
127
+ ------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------
128
+ ---------#-----T--------------------------------|------------------------------------o-------------------------------------------------------K----------------------------------
129
+ --------#------o--------------------------------|------------------------------------------------------o------------------------------------------------------------------------
130
+ -------TT------T--------------------------------|----So2-----US---------------------USSS-----@S--------tt-----------USoS-----U------U--2-----U--###--------------------------U--
131
+ -------TT----T---------TT----TT--------o--------|------K--K------------TT----TT-T----------------------tt----#---------K---------------K--------###--------------------B--------
132
+ ------#TT---TT---------TT----TT--------#t-------|------B---------------TT----TT------------------------tt----T---------B---------------B--------###--------------------B--------
133
+ -----##TT---TT----gg-#-Tt---kkT---gggg-TT--#y---|-gk-gog--k-k----ggggg-TT---kTT----k------k-------g----tt---kkT-T--k-kkb--k-k------k-gob--k-k---##----k-----------gggggb----k-k-
134
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
135
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
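
The `.lvls` format shown above is plain text: each level is a block of fixed-width rows of tile characters, and consecutive levels are separated by a line containing only `;`. A hedged parsing sketch, independent of the project's `load_batch`/`save_batch` helpers:

```python
# Split a .lvls file into per-level lists of tile rows; levels are separated by ';'.
def read_lvls(path: str) -> list[list[str]]:
    with open(path) as f:
        blocks = f.read().split(";")
    return [
        [row for row in block.splitlines() if row.strip()]
        for block in blocks
        if block.strip()
    ]

levels = read_lvls("models/example_policy/samples.lvls")
print(len(levels), "levels,", len(levels[0]), "rows each")
```
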
models/example_policy/samples.png ADDED
plots.py ADDED
@@ -0,0 +1,733 @@
1
+ import glob
2
+ import json
3
+ import os
4
+ import re
5
+
6
+ import numpy as np
7
+ import pandas as pds
8
+ import matplotlib
9
+ import matplotlib.pyplot as plt
10
+ from math import sqrt
11
+ import torch
12
+ from root import PRJROOT
13
+ from sklearn.manifold import TSNE
14
+ from itertools import product, chain
15
+ # from src.drl.drl_uses import load_cfgs
16
+ from src.gan.gankits import get_decoder, process_onehot
17
+ from src.gan.gans import nz
18
+ from src.smb.level import load_batch, hamming_dis, lvlhcat
19
+ from src.utils.datastruct import RingQueue
20
+ from src.utils.filesys import load_dict_json, getpath
21
+ from src.utils.img import make_img_sheet
22
+ from torch.distributions import Normal
23
+
24
+ matplotlib.rcParams["axes.formatter.limits"] = (-5, 5)
25
+
26
+
27
+ def print_compare_tab():
28
+ rand_lgp, rand_fhp, rand_divs = load_dict_json(
29
+ 'test_data/rand_policy/performance.csv', 'lgp', 'fhp', 'diversity'
30
+ )
31
+ rand_performance = {'lgp': rand_lgp, 'fhp': rand_fhp, 'diversity': rand_divs}
32
+
33
+ def _print_line(_data, minimise=False):
34
+ means = _data.mean(axis=-1)
35
+ stds = _data.std(axis=-1)
36
+ max_i, min_i = np.argmax(means), np.argmin(means)
37
+ mean_str_content = [*map(lambda x: '%.4g' % x, _data.mean(axis=-1))]
38
+ std_str_content = [*map(lambda x: '$\pm$%.3g' % x, _data.std(axis=-1))]
39
+ if minimise:
40
+ mean_str_content[min_i] = r'\textbf{%s}' % mean_str_content[min_i]
41
+ mean_str_content[max_i] = r'\textit{%s}' % mean_str_content[max_i]
42
+ std_str_content[min_i] = r'\textbf{%s}' % std_str_content[min_i]
43
+ std_str_content[max_i] = r'\textit{%s}' % std_str_content[max_i]
44
+ else:
45
+ mean_str_content[max_i] = r'\textbf{%s}' % mean_str_content[max_i]
46
+ mean_str_content[min_i] = r'\textit{%s}' % mean_str_content[min_i]
47
+ std_str_content[max_i] = r'\textbf{%s}' % std_str_content[max_i]
48
+ std_str_content[min_i] = r'\textit{%s}' % std_str_content[min_i]
49
+ print(' &', ' & '.join(mean_str_content), r'\\')
50
+ print(' & &', ' & '.join(std_str_content), r'\\')
51
+ pass
52
+
53
+ def _print_block(_task):
54
+ fds = [
55
+ f'sac/{_task}', f'egsac/{_task}', f'asyncsac/{_task}',
56
+ f'pmoe/{_task}', f'dvd/{_task}', f'sunrise/{_task}',
57
+ f'varpm-{_task}/l0.0_m5', f'varpm-{_task}/l0.1_m5', f'varpm-{_task}/l0.2_m5',
58
+ f'varpm-{_task}/l0.3_m5', f'varpm-{_task}/l0.4_m5', f'varpm-{_task}/l0.5_m5'
59
+ ]
60
+ rewards, divs = [], []
61
+ for fd in fds:
62
+ rewards.append([])
63
+ divs.append([])
64
+ # print(getpath())
65
+ for path in glob.glob(getpath('test_data', fd, '**', 'performance.csv'), recursive=True):
66
+ reward, div = load_dict_json(path, 'reward', 'diversity')
67
+ rewards[-1].append(reward)
68
+ divs[-1].append(div)
69
+ rewards = np.array(rewards)
70
+ divs = np.array(divs)
71
+
72
+ print(' & \\multirow{2}{*}{Reward}')
73
+ _print_line(rewards)
74
+ print(' \\cline{2-14}')
75
+ print(' & \\multirow{2}{*}{Diversity}')
76
+ _print_line(divs)
77
+ print(' \\cline{2-14}')
78
+ print(' & \\multirow{2}{*}{G-mean}')
79
+ gmean = np.sqrt(rewards * divs)
80
+ _print_line(gmean)
81
+
82
+ print(' \\cline{2-14}')
83
+ print(' & \\multirow{2}{*}{N-rank}')
84
+ r_rank = np.zeros_like(rewards.flatten())
85
+ r_rank[np.argsort(-rewards.flatten())] = np.linspace(1, len(r_rank), len(r_rank))
86
+
87
+ d_rank = np.zeros_like(divs.flatten())
88
+ d_rank[np.argsort(-divs.flatten())] = np.linspace(1, len(r_rank), len(r_rank))
89
+ n_rank = (r_rank.reshape([12, 5]) + d_rank.reshape([12, 5])) / (2 * 5)
90
+ _print_line(n_rank, True)
91
+
92
+ print(' \\multirow{8}{*}{MarioPuzzle}')
93
+ _print_block('fhp')
94
+ print(' \\midrule')
95
+ print(' \\multirow{8}{*}{MultiFacet}')
96
+ _print_block('lgp')
97
+ pass
98
+
99
+ def print_compare_tab_nonrl():
100
+ # rand_lgp, rand_fhp, rand_divs = load_dict_json(
101
+ # 'test_data/rand_policy/performance.csv', 'lgp', 'fhp', 'diversity'
102
+ # )
103
+ # rand_performance = {'lgp': rand_lgp, 'fhp': rand_fhp, 'diversity': rand_divs}
104
+
105
+ def _print_line(_data, minimise=False):
106
+ means = _data.mean(axis=-1)
107
+ stds = _data.std(axis=-1)
108
+ max_i, min_i = np.argmax(means), np.argmin(means)
109
+ mean_str_content = [*map(lambda x: '%.4g' % x, _data.mean(axis=-1))]
110
+ std_str_content = [*map(lambda x: '$\pm$%.3g' % x, _data.std(axis=-1))]
111
+ if minimise:
112
+ mean_str_content[min_i] = r'\textbf{%s}' % mean_str_content[min_i]
113
+ mean_str_content[max_i] = r'\textit{%s}' % mean_str_content[max_i]
114
+ std_str_content[min_i] = r'\textbf{%s}' % std_str_content[min_i]
115
+ std_str_content[max_i] = r'\textit{%s}' % std_str_content[max_i]
116
+ else:
117
+ mean_str_content[max_i] = r'\textbf{%s}' % mean_str_content[max_i]
118
+ mean_str_content[min_i] = r'\textit{%s}' % mean_str_content[min_i]
119
+ std_str_content[max_i] = r'\textbf{%s}' % std_str_content[max_i]
120
+ std_str_content[min_i] = r'\textit{%s}' % std_str_content[min_i]
121
+ print(' &', ' & '.join(mean_str_content), r'\\')
122
+ print(' & &', ' & '.join(std_str_content), r'\\')
123
+ pass
124
+
125
+ def _print_block(_task):
126
+ fds = [
127
+ f'GAN-{_task}', f'DDPM-{_task}',
128
+ f'varpm-{_task}/l0.0_m5', f'varpm-{_task}/l0.1_m5', f'varpm-{_task}/l0.2_m5',
129
+ f'varpm-{_task}/l0.3_m5', f'varpm-{_task}/l0.4_m5', f'varpm-{_task}/l0.5_m5'
130
+ ]
131
+ rewards, divs = [], []
132
+ for fd in fds:
133
+ rewards.append([])
134
+ divs.append([])
135
+ # print(getpath())
136
+ for path in glob.glob(getpath('test_data', fd, '**', 'performance.csv'), recursive=True):
137
+ reward, div = load_dict_json(path, 'reward', 'diversity')
138
+ rewards[-1].append(reward)
139
+ divs[-1].append(div)
140
+ rewards = np.array(rewards)
141
+ divs = np.array(divs)
142
+
143
+ print(' & \\multirow{2}{*}{Reward}')
144
+ _print_line(rewards)
145
+ print(' \\cline{2-10}')
146
+ print(' & \\multirow{2}{*}{Diversity}')
147
+ _print_line(divs)
148
+ print(' \\cline{2-10}')
149
+ # print(' & \\multirow{2}{*}{G-mean}')
150
+ # gmean = np.sqrt(rewards * divs)
151
+ # _print_line(gmean)
152
+ #
153
+ # print(' \\cline{2-10}')
154
+ # print(' & \\multirow{2}{*}{N-rank}')
155
+ # r_rank = np.zeros_like(rewards.flatten())
156
+ # r_rank[np.argsort(-rewards.flatten())] = np.linspace(1, len(r_rank), len(r_rank))
157
+ #
158
+ # d_rank = np.zeros_like(divs.flatten())
159
+ # d_rank[np.argsort(-divs.flatten())] = np.linspace(1, len(r_rank), len(r_rank))
160
+ # n_rank = (r_rank.reshape([8, 5]) + d_rank.reshape([8, 5])) / (2 * 5)
161
+ # _print_line(n_rank, True)
162
+
163
+ print(' \\multirow{4}{*}{MarioPuzzle}')
164
+ _print_block('fhp')
165
+ print(' \\midrule')
166
+ print(' \\multirow{4}{*}{MultiFacet}')
167
+ _print_block('lgp')
168
+ pass
169
+
170
+ def plot_cmp_learning_curves(task, save_path='', title=''):
171
+ plt.style.use('seaborn')
172
+ colors = [plt.plot([0, 1], [-1000, -1000])[0].get_color() for _ in range(6)]
173
+ plt.cla()
174
+ plt.style.use('default')
175
+
176
+ # colors = ('#5D2CAB', '#005BD4', '#007CE4', '#0097DD', '#00ADC4', '#00C1A5')
177
+ def _get_algo_data(fd):
178
+ res = []
179
+ for i in range(1, 6):
180
+ path = getpath(fd, f't{i}', 'step_tests.csv')
181
+ try:
182
+ data = pds.read_csv(path)
183
+ trajectory = [
184
+ [float(item['step']), float(item['r-avg']), float(item['diversity'])]
185
+ for _, item in data.iterrows()
186
+ ]
187
+ trajectory.sort(key=lambda x: x[0])
188
+ res.append(trajectory)
189
+ if len(trajectory) != 26:
190
+ print('Not complete (%d)/26:' % len(trajectory), path)
191
+ except FileNotFoundError:
192
+ print(path)
193
+ res = np.array(res)
194
+ # rdsum = res[:, :, 1] + res[:, :, 2]
195
+ gmean = np.sqrt(res[:, :, 1] * res[:, :, 2])
196
+ steps = res[0, :, 0]
197
+ # r_avgs = np.mean(res[:, :, 1], axis=0)
198
+ # r_stds = np.std(res[:, :, 1], axis=0)
199
+ # divs = np.mean(res[:, :, 2], axis=0)
200
+ # div_std = np.std(res[:, :, 2], axis=0)
201
+ _performances = {
202
+ 'reward': (np.mean(res[:, :, 1], axis=0), np.std(res[:, :, 1], axis=0)),
203
+ 'diversity': (np.mean(res[:, :, 2], axis=0), np.std(res[:, :, 2], axis=0)),
204
+ # 'rdsum': (np.mean(rdsum, axis=0), np.std(rdsum, axis=0)),
205
+ 'gmean': (np.mean(gmean, axis=0), np.std(gmean, axis=0)),
206
+ }
207
+ # print(_performances['gmean'])
208
+ return steps, _performances
209
+
210
+ def _plot_criterion(_ax, _criterion):
211
+ i, j, k = 0, 0, 0
212
+ for algo, (steps, _performances) in performances.items():
213
+ avgs, stds = _performances[_criterion]
214
+ if '\lambda' in algo:
215
+ ls = '-'
216
+ _c = colors[i]
217
+ i += 1
218
+ elif algo in {'SAC', 'EGSAC', 'ASAC'}:
219
+ ls = ':'
220
+ _c = colors[j]
221
+ j += 1
222
+ else:
223
+ ls = '--'
224
+ _c = colors[j]
225
+ j += 1
226
+ _ax.plot(steps, avgs, color=_c, label=algo, ls=ls)
227
+ _ax.fill_between(steps, avgs - stds, avgs + stds, color=_c, alpha=0.15)
228
+ _ax.grid(False)
229
+ # plt.plot(steps, avgs, label=algo)
230
+ # plt.plot(_performances, label=algo)
231
+ pass
232
+ _ax.set_xlabel('Time step')
233
+
234
+ fig, ax = plt.subplots(1, 3, figsize=(9.6, 3.2), dpi=250, width_ratios=[1, 1, 1])
235
+ # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(5, 4), dpi=256)
236
+ # fig, ax1 = plt.subplots(1, 1, figsize=(8, 3), dpi=256)
237
+ # ax2 = ax1.twinx()
238
+ # fig = plt.plot(figsize=(4, 3), dpi=256)
239
+ performances = {
240
+ 'SUNRISE': _get_algo_data(f'test_data/sunrise/{task}'),
241
+ '$\lambda$=0.0': _get_algo_data(f'test_data/varpm-{task}/l0.0_m5'),
242
+ 'DvD': _get_algo_data(f'test_data/dvd/{task}'),
243
+ '$\lambda$=0.1': _get_algo_data(f'test_data/varpm-{task}/l0.1_m5'),
244
+ 'PMOE': _get_algo_data(f'test_data/pmoe/{task}'),
245
+ '$\lambda$=0.2': _get_algo_data(f'test_data/varpm-{task}/l0.2_m5'),
246
+ 'SAC': _get_algo_data(f'test_data/sac/{task}'),
247
+ '$\lambda$=0.3': _get_algo_data(f'test_data/varpm-{task}/l0.3_m5'),
248
+ 'EGSAC': _get_algo_data(f'test_data/egsac/{task}'),
249
+ '$\lambda$=0.4': _get_algo_data(f'test_data/varpm-{task}/l0.4_m5'),
250
+ 'ASAC': _get_algo_data(f'test_data/asyncsac/{task}'),
251
+ '$\lambda$=0.5': _get_algo_data(f'test_data/varpm-{task}/l0.5_m5'),
252
+ }
253
+ # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/SAC', '**', 'step_tests.csv'))), 'SAC')
254
+ # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/EGSAC', '**', 'step_tests.csv'))), 'EGSAC')
255
+ # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/AsyncSAC', '**', 'step_tests.csv'))), 'AsyncSAC')
256
+ # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/SUNRISE', '**', 'step_tests.csv'))), 'SUNRISE')
257
+ # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/DvD-ES', '**', 'step_tests.csv'))), 'DvD-ES')
258
+ # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/lbd-m-crosstest/l0.04_m5', '**', 'step_tests.csv'))), 'NCESAC')
259
+
260
+
261
+
262
+ _plot_criterion(ax[0], 'reward')
263
+ _plot_criterion(ax[1], 'diversity')
264
+ # _plot_criterion(ax[2], 'rdsum')
265
+ _plot_criterion(ax[2], 'gmean')
266
+ # ax[0].set_title(f'{title} reward')
267
+ ax[0].set_title(f'Cumulative Reward')
268
+ ax[1].set_title('Diversity Score')
269
+ # ax[2].set_title('Summation')
270
+ ax[2].set_title('G-mean')
271
+ # plt.title(title)
272
+
273
+ lines, labels = fig.axes[-1].get_legend_handles_labels()
274
+ fig.suptitle(title, fontsize=14)
275
+ plt.tight_layout(pad=0.5)
276
+ if save_path:
277
+ plt.savefig(getpath(save_path))
278
+ else:
279
+ plt.show()
280
+
281
+ plt.cla()
282
+ plt.figure(figsize=(9.6, 2.4), dpi=250)
283
+ plt.grid(False)
284
+ plt.axis('off')
285
+ plt.yticks([1.0])
286
+ plt.legend(
287
+ lines, labels, loc='lower center', ncol=6, edgecolor='white', fontsize=15,
288
+ columnspacing=0.8, borderpad=0.16, labelspacing=0.2, handlelength=2.4, handletextpad=0.3
289
+ )
290
+ plt.tight_layout(pad=0.5)
291
+ plt.show()
292
+ pass
293
+
294
+ def plot_crosstest_scatters(rfunc, xrange=None, yrange=None, title=''):
295
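+ # Scatter final reward vs. diversity for every run of one task: dark-grey markers for the six
+ # baselines, lambda-coloured hollow markers for the varpm runs, and a translucent curve drawn
+ # through the non-dominated (Pareto) points.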
+ def get_pareto():
296
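+ # Return the points not dominated in (reward, diversity), sorted by reward.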
+ all_points = list(chain(*scatter_groups.values())) + cmp_points
297
+ res = []
298
+ for p in all_points:
299
+ non_dominated = True
300
+ for q in all_points:
301
+ if q[0] >= p[0] and q[1] >= p[1] and (q[0] > p[0] or q[1] > p[1]):
302
+ non_dominated = False
303
+ break
304
+ if non_dominated:
305
+ res.append(p)
306
+ res.sort(key=lambda item:item[0])
307
+ return np.array(res)
308
+ def _hex_color(_c):
309
+ return
310
+ scatter_groups = {}
311
+ all_lbd = set()
312
+ # Initialise
313
+ plt.style.use('seaborn-v0_8-dark-palette')
314
+ # plt.figure(figsize=(4, 4), dpi=256)
315
+ plt.figure(figsize=(2.5, 2.5), dpi=256)
316
+ plt.axes().set_axisbelow(True)
317
+
318
+ # Competitors' performances
319
+ cmp_folders = ['asyncsac', 'egsac', 'sac', 'sunrise', 'dvd', 'pmoe']
320
+ cmp_names = ['ASAC', 'EGSAC', 'SAC', 'SUNRISE', 'DvD', 'PMOE']
321
+ cmp_labels = ['A', 'E', 'S', 'R', 'D', 'M']
322
+ cmp_markers = ['2', 'x', '+', 'o', '*', 'D']
323
+ cmp_sizes = [42, 20, 32, 16, 24, 10, 10]
324
+ cmp_points = []
325
+ for name, folder, label, mk, s in zip(cmp_names, cmp_folders, cmp_labels, cmp_markers, cmp_sizes):
326
+ path_fmt = getpath('test_data', folder, rfunc, '*', 'performance.csv')
327
+ # print(path_fmt)
328
+ xs, ys = [], []
329
+ for path in glob.glob(path_fmt, recursive=True):
330
+ # print(path)
331
+ try:
332
+ x, y = load_dict_json(path, 'reward', 'diversity')
333
+ xs.append(x)
334
+ ys.append(y)
335
+ cmp_points.append([x, y])
336
+ # plt.text(x, y, label, size=7, weight='bold', va='center', ha='center', color='#202020')
337
+ except FileNotFoundError:
338
+ print(path)
339
+ if label in {'A', 'E', 'S'}:
340
+ plt.scatter(xs, ys, marker=mk, zorder=2, s=s, label=name, color='#202020')
341
+ else:
342
+ plt.scatter(
343
+ xs, ys, marker=mk, zorder=2, s=s, label=name, color=[0., 0., 0., 0.],
344
+ edgecolors='#202020', linewidths=1
345
+ )
346
+ # NCESAC performances
347
+ for path in glob.glob(getpath('test_data', f'varpm-{rfunc}', '**', 'performance.csv'), recursive=True):
348
+ try:
349
+ x, y = load_dict_json(path, 'reward', 'diversity')
350
+ key = os.path.normpath(path).split(os.sep)[-3]  # split on the OS path separator instead of a hard-coded '\\'
351
+ _, mtxt = key.split('_')
352
+ ltxt, _ = key.split('_')
353
+ lbd = float(ltxt[1:])
354
+ # if mtxt in {'m2', 'm3', 'm4'}:
355
+ # continue
356
+ all_lbd.add(lbd)
357
+ if key not in scatter_groups.keys():
358
+ scatter_groups[key] = []
359
+ scatter_groups[key].append([x, y])
360
+ except Exception as e:
361
+ print(path)
362
+ print(e)
363
+
364
+ palette = plt.get_cmap('seismic')
365
+ color_x = [0.2, 0.33, 0.4, 0.61, 0.67, 0.79]
366
+ colors = {lbd: matplotlib.colors.to_hex(c) for c, lbd in zip(palette(color_x), sorted(all_lbd))}
367
+ colors = {0.0: '#150080', 0.1: '#066598', 0.2: '#01E499', 0.3: '#9FD40C', 0.4: '#F3B020', 0.5: '#FA0000'}
368
+ for lbd in sorted(all_lbd): plt.plot([-20], [-20], label=f'$\\lambda={lbd:.1f}$', lw=6, c=colors[lbd])
369
+ markers = {2: 'o', 3: '^', 4: 'D', 5: 'p', 6: 'h'}
370
+ msizes = {2: 25, 3: 25, 4: 16, 5: 28, 6: 32}
371
+ for key, group in scatter_groups.items():
372
+ ltxt, mtxt = key.split('_')
373
+ l = float(ltxt[1:])
374
+ m = int(mtxt[1:])
375
+ arr = np.array(group)
376
+ plt.scatter(
377
+ arr[:, 0], arr[:, 1], marker=markers[m], s=msizes[m], color=[0., 0., 0., 0.], zorder=2,
378
+ edgecolors=colors[l], linewidths=1
379
+ )
380
+
381
+ plt.xlim(xrange)
382
+ plt.ylim(yrange)
383
+ # plt.xlabel('Task Reward')
384
+ # plt.ylabel('Diversity')
385
+ # plt.legend(ncol=2)
386
+ # plt.legend(
387
+ # ncol=2, loc='lower left', columnspacing=1.2, borderpad=0.0,
388
+ # handlelength=1, handletextpad=0.5, framealpha=0.
389
+ # )
390
+ pareto = get_pareto()
391
+ plt.plot(
392
+ pareto[:, 0], pareto[:, 1], color='black', alpha=0.18, lw=6, zorder=3,
393
+ solid_joinstyle='round', solid_capstyle='round'
394
+ )
395
+ # plt.plot([88, 98, 98, 88, 88], [35, 35, 0.2, 0.2, 35], color='black', alpha=0.3, lw=1.5)
396
+ # plt.xticks(fontsize=16)
397
+ # plt.yticks(fontsize=16)
398
+ # plt.xticks([(1+space) * (m-mlow) + 0.5 for m in ms], [f'm={m}' for m in ms])
399
+ plt.title(title)
400
+ plt.grid()
401
+ plt.tight_layout(pad=0.4)
402
+ plt.show()
403
+
404
+ def plot_varpm_heat(task, name):
405
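+ # For each (m, lambda) setting, average performance.csv over the five trials and save
+ # mean/std heat maps of reward, diversity and G-mean to results/heat/.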
+ def _get_score(m, l):
406
+ fd = getpath('test_data', f'varpm-{task}', f'l{l}_m{m}')
407
+ rewards, divs = [], []
408
+ for i in range(5):
409
+ reward, div = load_dict_json(f'{fd}/t{i+1}/performance.csv', 'reward', 'diversity')
410
+ rewards.append(reward)
411
+ divs.append(div)
412
+ gmean = [sqrt(r * d) for r, d in zip(rewards, divs)]
413
+ return np.mean(rewards), np.std(rewards), \
414
+ np.mean(divs), np.std(divs), \
415
+ np.mean(gmean), np.std(gmean)
416
+
417
+ def _plot_map(avg_map, std_map, criterion):
418
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3), dpi=256, width_ratios=(1, 1))
419
+ heat1 = ax1.imshow(avg_map, cmap='spring')
420
+ heat2 = ax2.imshow(std_map, cmap='spring')
421
+ ax1.set_xlim([-0.5, 5.5])
422
+ ax1.set_xticks([0, 1, 2, 3, 4, 5], ['$\lambda$=0.0', '$\lambda$=0.1', '$\lambda$=0.2', '$\lambda$=0.3', '$\lambda$=0.4', '$\lambda$=0.5'])
423
+ ax1.set_ylim([-0.5, 3.5])
424
+ ax1.set_yticks([0, 1, 2, 3], ['m=5', 'm=4', 'm=3', 'm=2'])
425
+ ax1.set_title('Average')
426
+ for x, y in product([0, 1, 2, 3, 4, 5], [0, 1, 2, 3]):
427
+ v = avg_map[y, x]
428
+ s = '%.4f' % v
429
+ if v >= 1000: s = s[:4]
430
+ elif v >= 1: s = s[:5]
431
+ else: s = s[1:6]
432
+ ax1.text(x, y, s, va='center', ha='center')
433
+ plt.colorbar(heat1, ax=ax1, shrink=0.9)
434
+ ax2.set_xlim([-0.5, 5.5])
435
+ ax2.set_xticks([0, 1, 2, 3, 4, 5], ['$\lambda$=0.0', '$\lambda$=0.1', '$\lambda$=0.2', '$\lambda$=0.3', '$\lambda$=0.4', '$\lambda$=0.5'])
436
+ ax2.set_ylim([-0.5, 3.5])
437
+ ax2.set_yticks([0, 1, 2, 3], ['m=5', 'm=4', 'm=3', 'm=2'])
438
+ for x, y in product([0, 1, 2, 3, 4, 5], [0, 1, 2, 3]):
439
+ v = std_map[y, x]
440
+ s = '%.4f' % v
441
+ if v >= 1000: s = s[:4]
442
+ elif v >= 1: s = s[:5]
443
+ else: s = s[1:6]
444
+ ax2.text(x, y, s, va='center', ha='center')
445
+ ax2.set_title('Standard Deviation')
446
+ plt.colorbar(heat2, ax=ax2, shrink=0.9)
447
+
448
+ fig.suptitle(f'{name}: {criterion}', fontsize=14)
449
+ plt.tight_layout()
450
+ # plt.show()
451
+ plt.savefig(getpath(f'results/heat/{name}-{criterion}.png'))
452
+
453
+ r_mean_map, r_std_map, d_mean_map, d_std_map, g_mean_map, g_std_map \
454
+ = (np.zeros([4, 6], dtype=float) for _ in range(6))
455
+ ms = [2, 3, 4, 5]
456
+ ls = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']
457
+ for i, j in product(range(4), range(6)):
458
+ r_mean, r_std, d_mean, d_std, g_mean, g_std = _get_score(ms[i], ls[j])
459
+ r_mean_map[i, j] = r_mean
460
+ r_std_map[i, j] = r_std
461
+ d_mean_map[i, j] = d_mean
462
+ d_std_map[i, j] = d_std
463
+ g_mean_map[i, j] = g_mean
464
+ g_std_map[i, j] = g_std
465
+
466
+ _plot_map(r_mean_map, r_std_map, 'Reward')
467
+ _plot_map(d_mean_map, d_std_map, 'Diversity')
468
+ _plot_map(g_mean_map, g_std_map,'G-mean')
469
+ # _plot_map(g_mean_map, g_std_map,'G-mean')
470
+
471
+ def vis_samples():
472
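+ # Render the first 10 generated levels of each run into one image sheet per run
+ # (currently only the DDPM-fhp results; the varpm and baseline loops are kept commented out).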
+ # for l, m in product(['0.0', '0.1', '0.2', '0.3', '0.4', '0.5'], [2, 3, 4, 5]):
473
+ # for i in range(1, 6):
474
+ # lvls = load_batch(f'{PRJROOT}/test_data/varpm-fhp/l{l}_m{m}/t{i}/samples.lvls')
475
+ # imgs = [lvl.to_img(save_path=None) for lvl in lvls[:10]]
476
+ # make_img_sheet(imgs, 1, save_path=f'{PRJROOT}/test_data/varpm-fhp/l{l}_m{m}/t{i}/samples.png')
477
+ # for algo in ['sac', 'egsac', 'asyncsac', 'dvd', 'sunrise', 'pmoe']:
478
+ # for i in range(1, 6):
479
+ # lvls = load_batch(f'{PRJROOT}/test_data/{algo}/fhp/t{i}/samples.lvls')
480
+ # imgs = [lvl.to_img(save_path=None) for lvl in lvls[:10]]
481
+ # make_img_sheet(imgs, 1, save_path=f'{PRJROOT}/test_data/{algo}/fhp/t{i}/samples.png')
482
+ for i in range(1, 6):
483
+ lvls = load_batch(f'{PRJROOT}/test_data/DDPM-fhp/t{i}/samples.lvls')
484
+ imgs = [lvl.to_img(save_path=None) for lvl in lvls[:10]]
485
+ make_img_sheet(imgs, 1, save_path=f'{PRJROOT}/test_data/DDPM-fhp/t{i}/samples.png')
486
+ pass
487
+ pass
488
+
489
+ def make_tsne(task, title, n=500, save_path=None):
490
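+ # Compute (and cache as .npy, if not already on disk) the pairwise Hamming-distance matrix
+ # between the first n sampled levels of every algorithm, embed it with t-SNE
+ # (metric='precomputed') and scatter the embedding, reusing one colour per
+ # (lambda variant, baseline) pair.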
+ if not os.path.exists(getpath('test_data', f'samples_dist-{task}_{n}.npy')):
491
+ samples = []
492
+ for algo in ['dvd', 'egsac', 'pmoe', 'sunrise', 'asyncsac', 'sac']:
493
+ for t in range(5):
494
+ lvls = load_batch(getpath('test_data', algo, task, f't{t+1}', 'samples.lvls'))
495
+ samples += lvls[:n]
496
+ for l in ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']:
497
+ for t in range(5):
498
+ lvls = load_batch(getpath('test_data', f'varpm-{task}', f'l{l}_m5', f't{t+1}', 'samples.lvls'))
499
+ samples += lvls[:n]
500
+ distmat = []
501
+ for a in samples:
502
+ dist_list = []
503
+ for b in samples:
504
+ dist_list.append(hamming_dis(a, b))
505
+ distmat.append(dist_list)
506
+ distmat = np.array(distmat)
507
+ np.save(getpath('test_data', f'samples_dist-{task}_{n}.npy'), distmat)
508
+
509
+ labels = (
510
+ '$\lambda$=0.0', '$\lambda$=0.1', '$\lambda$=0.2', '$\lambda$=0.3', '$\lambda$=0.4',
511
+ '$\lambda$=0.5', 'DvD', 'EGSAC', 'PMOE', 'SUNRISE', 'ASAC', 'SAC'
512
+ )
513
+ tsne = TSNE(learning_rate='auto', n_components=2, metric='precomputed')
514
+ print(np.load(getpath('test_data', f'samples_dist-{task}_{n}.npy')).shape)
515
+ data = np.load(getpath('test_data', f'samples_dist-{task}_{n}.npy'))
516
+ embx = np.array(tsne.fit_transform(data))
517
+
518
+ plt.style.use('seaborn-dark-palette')
519
+ plt.figure(figsize=(5, 5), dpi=384)
520
+ colors = [plt.plot([-1000, -1100], [0, 0])[0].get_color() for _ in range(6)]
521
+ for i in range(6):
522
+ x, y = embx[i*n*5:(i+1)*n*5, 0], embx[i*n*5:(i+1)*n*5, 1]
523
+ plt.scatter(x, y, s=10, label=labels[i], marker='x', c=colors[i])
524
+ for i in range(6, 12):
525
+ x, y = embx[i*n*5:(i+1)*n*5, 0], embx[i*n*5:(i+1)*n*5, 1]
526
+ plt.scatter(x, y, s=8, linewidths=0, label=labels[i], c=colors[i-6])
527
+ # plt.scatter(embx[100:200, 0], embx[100:200, 1], c=colors[1], s=12, linewidths=0, label='Killer')
528
+ # plt.scatter(embx[200:, 0], embx[200:, 1], c=colors[2], s=12, linewidths=0, label='Collector')
529
+ # for i in range(4):
530
+ # plt.text(embx[i+100, 0], embx[i+100, 1], str(i+1))
531
+ # plt.text(embx[i+200, 0], embx[i+200, 1], str(i+1))
532
+ # pass
533
+ # for emb, lb, c in zip(embs, labels,colors):
534
+ # plt.scatter(emb[:,0], emb[:,1], c=c, label=lb, alpha=0.15, linewidths=0, s=7)
535
+
536
+ # xspan = 1.08 * max(abs(embx[:, 0].max()), abs(embx[:, 0].min()))
537
+ # yspan = 1.08 * max(abs(embx[:, 1].max()), abs(embx[:, 1].min()))
538
+
539
+ xrange = [1.05 * embx[:, 0].min(), 1.05 * embx[:, 0].max()]
540
+ yrange = [1.05 * embx[:, 1].min(), 1.25 * embx[:, 1].max()]
541
+
542
+ plt.xlim(xrange)
543
+ plt.ylim(yrange)
544
+ plt.xticks([])
545
+ plt.yticks([])
546
+ # plt.legend(ncol=6, handletextpad=0.02, labelspacing=0.05, columnspacing=0.16)
547
+ # plt.xticks([-xspan, -0.5 * xspan, 0, 0.5 * xspan, xspan], [''] * 5)
548
+ # plt.yticks([-yspan, -0.5 * yspan, 0, 0.6 * yspan, yspan], [''] * 5)
549
+ plt.title(title)
550
+ plt.legend(loc='upper center', ncol=6, fontsize=9, handlelength=.5, handletextpad=0.5, columnspacing=0.3, framealpha=0.)
551
+ plt.tight_layout(pad=0.2)
552
+
553
+ if save_path:
554
+ plt.savefig(getpath(save_path))
555
+ else:
556
+ plt.show()
557
+
558
+ def _prob_fmt(p, digitals=3, threshold=0.001):
559
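+ # Format a probability for the LaTeX tables: values below `threshold` print as $\approx 0$,
+ # otherwise the leading zero is dropped (e.g. '$.123$'), with 1.0 special-cased to '$1.00$'.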
+ fmt = '%.' + str(digitals) + 'f'
560
+ if p < threshold:
561
+ return '$\\approx 0$'
562
+ else:
563
+ txt = '$%s$' % ((fmt % p)[1:])
564
+ if txt == '$.000$':
565
+ txt = '$1.00$'
566
+ return txt
567
+
568
+ def _g_fmt(v, digitals=4):
569
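+ # Format a value to `digitals` significant digits, right-padding with zeros so columns align.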
+ fmt = '%.' + str(digitals) + 'g'
570
+ txt = (fmt % v)
571
+ lack = digitals - len(txt.replace('-', '').replace('.', ''))
572
+ if lack > 0 and '.' not in txt:
573
+ txt += '.'
574
+ return txt + '0' * lack
575
+ pass
576
+
577
+ def print_selection_prob(path, h=15, runs=2):
578
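+ # Roll out the trained mixture policy from a stored initial latent segment for `runs` episodes
+ # of `h` steps: record the selection weights (betas) and the sampled component at every step,
+ # decode the latent trajectory into a level image, and print one LaTeX row of probabilities
+ # per component with the selected entries in bold.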
+ s0 = 0
579
+ model = torch.load(getpath(f'{path}/policy.pth'), map_location='cpu')
580
+ model.requires_grad_(False)
581
+ model.to('cpu')
582
+ n = 11
583
+ # n = load_cfgs(path, 'N')
584
+ # print(model.m)
585
+
586
+ init_vec = np.load(getpath('analysis/initial_seg.npy'))[s0]
587
+ decoder = get_decoder(device='cpu')
588
+ obs_buffer = RingQueue(n)
589
+ for r in range(runs):
590
+ for _ in range(h): obs_buffer.push(np.zeros([nz]))
591
+ obs_buffer.push(init_vec)
592
+ level_latvecs = [init_vec]
593
+ probs = np.zeros([model.m, h])
594
+ # probs = []
595
+ selects = []
596
+ for t in range(h):
597
+ # probs.append([])
598
+ obs = torch.tensor(np.concatenate(obs_buffer.to_list(), axis=-1), dtype=torch.float).view([1, -1])
599
+ muss, stdss, betas = model.get_intermediate(obs)  # obs is already a tensor; re-wrapping it in torch.tensor only copies and warns
600
+ i = torch.multinomial(betas.squeeze(), 1).item()
601
+ # print(i)
602
+ mu, std = muss[0][i], stdss[0][i]
603
+ action = Normal(mu, std).rsample([1]).squeeze().numpy()
604
+ # print(action)
605
+ # print(mu)
606
+ # print(std)
607
+ # print(action.numpy())
608
+ obs_buffer.push(action)
609
+ level_latvecs.append(action)
610
+ # i = torch.multinomial(betas.squeeze(), 1).item()
611
+ # print(i)
612
+ probs[:, t] = betas.squeeze().numpy()
613
+ selects.append(i)
614
+ pass
615
+ onehots = decoder(torch.tensor(level_latvecs).view(-1, nz, 1, 1))
616
+ segs = process_onehot(onehots)
617
+ lvl = lvlhcat(segs)
618
+ lvl.to_img(f'figures/gen_process/run{r}-01.png')
619
+ txts = [[_prob_fmt(p) for p in row] for row in probs]
620
+ for t, i in enumerate(selects):
621
+ txts[i][t] = r'$\boldsymbol{%s}$' % txts[i][t][1:-1]
622
+ for i, txt in enumerate(txts):
623
+ print(f' & $\\beta_{i+1}$ &', ' & '.join(txt), r'\\')
624
+ print(r'\midrule')
625
+
626
+ pass
627
+
628
+ def calc_selection_freqs(task, n):
629
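+ # For every (lambda, m, trial) policy, roll out runs x h steps from the stored initial segments
+ # and count how often each mixture component is selected; the sorted per-component frequencies
+ # (one row per lambda/trial, concatenated over m) are dumped to analysis/select_freqs-{task}.json.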
+ def _count_one_init():
630
+ counts = np.zeros([model.m])
631
+ # init_vec = np.load(getpath('analysis/initial_seg.npy'))
632
+ obs_buffer = RingQueue(n)
633
+ for _ in range(runs):
634
+ for _ in range(h): obs_buffer.push(np.zeros([len(init_vecs), nz]))
635
+ obs_buffer.push(init_vecs)
636
+ # level_latvecs = [init_vec]
637
+ for _ in range(h):
638
+ obs = np.concatenate(obs_buffer.to_list(), axis=-1)
639
+ obs = torch.tensor(obs, device='cuda:0', dtype=torch.float)
640
+ muss, stdss, betas = model.get_intermediate(obs)
641
+ selects = torch.multinomial(betas.squeeze(), 1).squeeze()
642
+ mus = muss[[*range(len(init_vecs))], selects, :]
643
+ stds = stdss[[*range(len(init_vecs))], selects, :]
644
+ actions = Normal(mus, stds).rsample().squeeze().cpu().numpy()
645
+ obs_buffer.push(actions)
646
+ for i in selects:
647
+ counts[i] = counts[i] + 1
648
+ return counts
649
+ # onehots = decoder(torch.tensor(level_latvecs).view(-1, nz, 1, 1))
650
+ pass
651
+ pass
652
+ init_vecs = np.load(getpath('analysis/initial_seg.npy'))
653
+ freqs = [[] for _ in range(30)]
654
+ start_line = 0
655
+ for l in ('0.0', '0.1', '0.2', '0.3', '0.4', '0.5'):
656
+ print(r' \midrule')
657
+ for t, m in product(range(1, 6), (2, 3, 4, 5)):
658
+ path = getpath(f'test_data/varpm-{task}/l{l}_m{m}/t{t}')
659
+ model = torch.load(getpath(f'{path}/policy.pth'), map_location='cuda:0')
660
+ model.requires_grad_(False)
661
+ freq = np.zeros([m])
662
+ # n = load_cfgs(path, 'N')
663
+ runs, h = 100, 25
664
+ freq += _count_one_init()
665
+ freq /= (len(init_vecs) * runs * h)
666
+ freq = np.sort(freq)[::-1]
667
+ i = start_line + t - 1
668
+ freqs[i] += freq.tolist()
669
+ print(freqs[i])
670
+ start_line += 5
671
+ print(freqs)
672
+ with open(getpath(f'analysis/select_freqs-{task}.json'), 'w') as f:
673
+ json.dump(freqs, f)
674
+
675
+ def print_selection_freq():
676
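+ # Print the cached selection frequencies as LaTeX table rows grouped by lambda,
+ # computing them first via calc_selection_freqs if the JSON cache is missing.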
+ # task, n = 'lgp', 5
677
+ task, n = 'fhp', 11
678
+ if not os.path.exists(getpath(f'analysis/select_freqs-{task}.json')):
679
+ calc_selection_freqs(task, n)
680
+ with open(getpath(f'analysis/select_freqs-{task}.json'), 'r') as f:
681
+ freqs = json.load(f)
682
+ lbds = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']
683
+ for i, row_data in enumerate(freqs):
684
+ if i % 5 == 0:
685
+ print(r' \midrule')
686
+ print(r' \multirow{5}{*}{$%s$}' % lbds[i//5])
687
+ txt = ' & '.join(map(_prob_fmt, row_data))
688
+ print(f' & {i%5+1} &', txt, r'\\')
689
+
690
+ def print_individual_performances(task):
691
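+ # Print one LaTeX row per (m, lambda) listing the five trials' (reward, diversity) pairs,
+ # sorted by descending reward.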
+ for m, l in product((2, 3, 4, 5), ('0.0', '0.1', '0.2', '0.3', '0.4', '0.5')):
692
+ values = []
693
+ if l == '0.0':
694
+ print(r' \midrule')
695
+ print(r' \multirow{6}{*}{%d}' % m)
696
+ for t in range(1, 6):
697
+ path = f'test_data/varpm-{task}/l{l}_m{m}/t{t}/performance.csv'
698
+ reward, diversity = load_dict_json(path, 'reward', 'diversity')
699
+ values.append([reward, diversity])
700
+ values.sort(key=lambda item: -item[0])
701
+ values = [*chain(*values)]
702
+ txts = [_g_fmt(v) for v in values]
703
+ print(' &', f'${l}$ & ', ' & '.join(txts), r'\\')
704
+ pass
705
+
706
+ if __name__ == '__main__':
707
+ # print_selection_prob('test_data/varpm-fhp/l0.5_m5/t5')
708
+ # print_selection_prob('test_data/varpm-fhp/l0.1_m5/t5')
709
+ # print_selection_freq()
710
+ # print_compare_tab_nonrl()
711
+ # print_individual_performances('fhp')
712
+ # print('\n\n')
713
+ # print_individual_performances('lgp')
714
+
715
+ # plot_cmp_learning_curves('fhp', save_path='results/learning_curves/fhp.png', title='MarioPuzzle')
716
+ # plot_cmp_learning_curves('lgp', save_path='results/learning_curves/lgp.png', title='MultiFacet')
717
+
718
+ # plot_crosstest_scatters('fhp', title='MarioPuzzle')
719
+ # plot_crosstest_scatters('lgp', title='MultiFacet')
720
+ # # plot_crosstest_scatters('fhp', yrange=(0, 2500), xrange=(20, 70), title='MarioPuzzle')
721
+ # plot_crosstest_scatters('lgp', yrange=(0, 1500), xrange=(20, 50), title='MultiFacet')
722
+ # plot_crosstest_scatters('lgp', yrange=(0, 800), xrange=(44, 48), title=' ')
723
+
724
+
725
+ # plot_varpm_heat('fhp', 'MarioPuzzle')
726
+ # plot_varpm_heat('lgp', 'MultiFacet')
727
+
728
+ vis_samples()
729
+
730
+ # make_tsne('fhp', 'MarioPuzzle', n=100)
731
+ # make_tsne('lgp', 'MultiFacet', n=100)
732
+ pass
733
+
pyproject.toml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "ncerl"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Ziqi Wang"]
6
+ readme = "README.md"
7
+
8
+ [tool.poetry.dependencies]
9
+ python = "^3.9"
10
+ JPype1 = "1.3.0"
11
+ dtw = "1.4.0"
12
+ torch = "1.8.1"
13
+ numpy = "^2.0.0"
14
+ pillow = "10.0.0"
15
+ matplotlib = "3.6.3"
16
+ pandas = "1.3.2"
17
+
18
+
19
+ [build-system]
20
+ requires = ["poetry-core"]
21
+ build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
Binary file (4.87 kB).
 
root.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import os
2
+
3
+ PRJROOT = os.path.dirname(os.path.realpath(__file__))
smb/Mario-AI-Framework.jar ADDED
Binary file (206 kB).
 
smb/assets/#.png ADDED
smb/assets/1.png ADDED
smb/assets/2.png ADDED
smb/assets/@.png ADDED
smb/assets/B.png ADDED
smb/assets/BSP.png ADDED
smb/assets/CB1.png ADDED
smb/assets/CB2.png ADDED
smb/assets/L.png ADDED
smb/assets/ML.png ADDED
smb/assets/MM.png ADDED
smb/assets/MR.png ADDED
smb/assets/MS.png ADDED
smb/assets/Q.png ADDED
smb/assets/S.png ADDED
smb/assets/TLP.png ADDED
smb/assets/TRP.png ADDED
smb/assets/TSP.png ADDED
smb/assets/U.png ADDED
smb/assets/X.png ADDED
smb/assets/[.png ADDED
smb/assets/].png ADDED
smb/assets/chomper.png ADDED
smb/assets/g.png ADDED
smb/assets/k.png ADDED
smb/assets/o.png ADDED
smb/assets/r.png ADDED
smb/assets/stalk.png ADDED
smb/assets/wingk.png ADDED
smb/assets/wingr.png ADDED
smb/assets/y.png ADDED
smb/img/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ This folder contains the game graphics. All the graphics files have been modified by hand from the original files grabbed from https://www.spriters-resource.com/nes/supermariobros/, except for the font.gif file, which is the same as the file from the MarioAI framework.
smb/img/background.png ADDED