Spaces:
Sleeping
Sleeping
baiyanlali-zhao
commited on
Commit
·
eaf2e33
1
Parent(s):
7da037c
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +166 -0
- README.md +32 -12
- analysis/generate.py +71 -0
- analysis/initial_seg.npy +3 -0
- analysis/tests.py +213 -0
- app.py +46 -0
- generate_and_play.py +27 -0
- models/decoder.pth +3 -0
- models/example_policy/cfgs.json +1 -0
- models/example_policy/policy.pth +3 -0
- models/example_policy/samples.lvls +135 -0
- models/example_policy/samples.png +0 -0
- plots.py +733 -0
- pyproject.toml +21 -0
- requirements.txt +0 -0
- root.py +3 -0
- smb/Mario-AI-Framework.jar +0 -0
- smb/assets/#.png +0 -0
- smb/assets/1.png +0 -0
- smb/assets/2.png +0 -0
- smb/assets/@.png +0 -0
- smb/assets/B.png +0 -0
- smb/assets/BSP.png +0 -0
- smb/assets/CB1.png +0 -0
- smb/assets/CB2.png +0 -0
- smb/assets/L.png +0 -0
- smb/assets/ML.png +0 -0
- smb/assets/MM.png +0 -0
- smb/assets/MR.png +0 -0
- smb/assets/MS.png +0 -0
- smb/assets/Q.png +0 -0
- smb/assets/S.png +0 -0
- smb/assets/TLP.png +0 -0
- smb/assets/TRP.png +0 -0
- smb/assets/TSP.png +0 -0
- smb/assets/U.png +0 -0
- smb/assets/X.png +0 -0
- smb/assets/[.png +0 -0
- smb/assets/].png +0 -0
- smb/assets/chomper.png +0 -0
- smb/assets/g.png +0 -0
- smb/assets/k.png +0 -0
- smb/assets/o.png +0 -0
- smb/assets/r.png +0 -0
- smb/assets/stalk.png +0 -0
- smb/assets/wingk.png +0 -0
- smb/assets/wingr.png +0 -0
- smb/assets/y.png +0 -0
- smb/img/README.md +1 -0
- smb/img/background.png +0 -0
.gitignore
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/misc/
|
2 |
+
/.idea/
|
3 |
+
|
4 |
+
# Byte-compiled / optimized / DLL files
|
5 |
+
__pycache__/
|
6 |
+
*.py[cod]
|
7 |
+
*$py.class
|
8 |
+
|
9 |
+
# C extensions
|
10 |
+
*.so
|
11 |
+
|
12 |
+
# Distribution / packaging
|
13 |
+
.Python
|
14 |
+
build/
|
15 |
+
develop-eggs/
|
16 |
+
dist/
|
17 |
+
downloads/
|
18 |
+
eggs/
|
19 |
+
.eggs/
|
20 |
+
lib/
|
21 |
+
lib64/
|
22 |
+
parts/
|
23 |
+
sdist/
|
24 |
+
var/
|
25 |
+
wheels/
|
26 |
+
share/python-wheels/
|
27 |
+
*.egg-info/
|
28 |
+
.installed.cfg
|
29 |
+
*.egg
|
30 |
+
MANIFEST
|
31 |
+
|
32 |
+
# PyInstaller
|
33 |
+
# Usually these files are written by a python script from a template
|
34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
35 |
+
*.manifest
|
36 |
+
*.spec
|
37 |
+
|
38 |
+
# Installer logs
|
39 |
+
pip-log.txt
|
40 |
+
pip-delete-this-directory.txt
|
41 |
+
|
42 |
+
# Unit test / coverage reports
|
43 |
+
htmlcov/
|
44 |
+
.tox/
|
45 |
+
.nox/
|
46 |
+
.coverage
|
47 |
+
.coverage.*
|
48 |
+
.cache
|
49 |
+
nosetests.xml
|
50 |
+
coverage.xml
|
51 |
+
*.cover
|
52 |
+
*.py,cover
|
53 |
+
.hypothesis/
|
54 |
+
.pytest_cache/
|
55 |
+
cover/
|
56 |
+
|
57 |
+
# Translations
|
58 |
+
*.mo
|
59 |
+
*.pot
|
60 |
+
|
61 |
+
# Django stuff:
|
62 |
+
*.log
|
63 |
+
local_settings.py
|
64 |
+
db.sqlite3
|
65 |
+
db.sqlite3-journal
|
66 |
+
|
67 |
+
# Flask stuff:
|
68 |
+
instance/
|
69 |
+
.webassets-cache
|
70 |
+
|
71 |
+
# Scrapy stuff:
|
72 |
+
.scrapy
|
73 |
+
|
74 |
+
# Sphinx documentation
|
75 |
+
docs/_build/
|
76 |
+
|
77 |
+
# PyBuilder
|
78 |
+
.pybuilder/
|
79 |
+
target/
|
80 |
+
|
81 |
+
# Jupyter Notebook
|
82 |
+
.ipynb_checkpoints
|
83 |
+
|
84 |
+
# IPython
|
85 |
+
profile_default/
|
86 |
+
ipython_config.py
|
87 |
+
|
88 |
+
# pyenv
|
89 |
+
# For a library or package, you might want to ignore these files since the code is
|
90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
91 |
+
# .python-version
|
92 |
+
|
93 |
+
# pipenv
|
94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
97 |
+
# install all needed dependencies.
|
98 |
+
#Pipfile.lock
|
99 |
+
|
100 |
+
# poetry
|
101 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
103 |
+
# commonly ignored for libraries.
|
104 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
105 |
+
#poetry.lock
|
106 |
+
|
107 |
+
# pdm
|
108 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
109 |
+
#pdm.lock
|
110 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
111 |
+
# in version control.
|
112 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
113 |
+
.pdm.toml
|
114 |
+
.pdm-python
|
115 |
+
.pdm-build/
|
116 |
+
|
117 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
118 |
+
__pypackages__/
|
119 |
+
|
120 |
+
# Celery stuff
|
121 |
+
celerybeat-schedule
|
122 |
+
celerybeat.pid
|
123 |
+
|
124 |
+
# SageMath parsed files
|
125 |
+
*.sage.py
|
126 |
+
|
127 |
+
# Environments
|
128 |
+
.env
|
129 |
+
.venv
|
130 |
+
env/
|
131 |
+
venv/
|
132 |
+
ENV/
|
133 |
+
env.bak/
|
134 |
+
venv.bak/
|
135 |
+
|
136 |
+
# Spyder project settings
|
137 |
+
.spyderproject
|
138 |
+
.spyproject
|
139 |
+
|
140 |
+
# Rope project settings
|
141 |
+
.ropeproject
|
142 |
+
|
143 |
+
# mkdocs documentation
|
144 |
+
/site
|
145 |
+
|
146 |
+
# mypy
|
147 |
+
.mypy_cache/
|
148 |
+
.dmypy.json
|
149 |
+
dmypy.json
|
150 |
+
|
151 |
+
# Pyre type checker
|
152 |
+
.pyre/
|
153 |
+
|
154 |
+
# pytype static type analyzer
|
155 |
+
.pytype/
|
156 |
+
|
157 |
+
# Cython debug symbols
|
158 |
+
cython_debug/
|
159 |
+
|
160 |
+
# PyCharm
|
161 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
162 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
163 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
164 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
165 |
+
#.idea/
|
166 |
+
/generation_results
|
README.md
CHANGED
@@ -1,12 +1,32 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Negatively Correlated Ensemble RL
|
2 |
+
|
3 |
+
|
4 |
+
### Verified environment
|
5 |
+
* Python 3.9.6
|
6 |
+
* JPype 1.3.0
|
7 |
+
* dtw 1.4.0
|
8 |
+
* scipy 1.7.2
|
9 |
+
* torch 1.8.2+cu111
|
10 |
+
* numpy 1.20.3
|
11 |
+
* gym 0.21.0
|
12 |
+
* scipy 1.7.2
|
13 |
+
* Pillow 10.0.0
|
14 |
+
* matplotlib 3.6.3
|
15 |
+
* pandas 1.3.2
|
16 |
+
* sklearn 1.0.1
|
17 |
+
|
18 |
+
### How to use
|
19 |
+
|
20 |
+
All training runs are launched by running `train.py` with an option and its arguments. For example, executing `python train.py ncesac --lbd 0.3 --m 5` will train NCERL with the hyperparameters set to $\lambda = 0.3, m=5$.
|
21 |
+
The plotting script is `plots.py`.
|
22 |
+
|
23 |
+
* `python train.py gan`: to train a decoder which maps a continuous action to a game level segment.
|
24 |
+
* `python train.py sac`: to train a standard SAC as the policy for online game level generation
|
25 |
+
* `python train.py asyncsac`: to train a SAC with an asynchronous evaluation environment as the policy for online game level generation
|
26 |
+
* `python train.py ncesac`: to train an NCERL based on SAC as the policy for online game level generation
|
27 |
+
* `python train.py egsac`: to train an episodic generative SAC (see paper [*The fun facets of Mario: Multifaceted experience-driven PCG via reinforcement learning*](https://dl.acm.org/doi/abs/10.1145/3555858.3563282?casa_token=AHQWYSj_GyoAAAAA:MhwOltqfijP1NQj-c6NaTQikCnlNwyaMky07gCvTK5ZlSq063ew40awAcqEcw6S5zG9Sq9ZyDsspuaM)) as the policy for online game level generation
|
28 |
+
* `python train.py pmoe`: to train a PMOE, i.e. a probabilistic mixture-of-experts policy (see paper [*Probabilistic Mixture-of-Experts for Efficient Deep Reinforcement Learning*](https://arxiv.org/abs/2104.09122)) as the policy for online game level generation
|
29 |
+
* `python train.py sunrise`: to train a SUNRISE (see paper [*SUNRISE: A Simple Unified Framework for Ensemble Learning in Deep Reinforcement Learning*](https://proceedings.mlr.press/v139/lee21g.html)) as the policy for online game level generation
|
30 |
+
* `python train.py dvd`: to train a DvD-SAC (see paper [*Effective Diversity in Population Based Reinforcement Learning*](https://proceedings.neurips.cc/paper_files/paper/2020/hash/d1dc3a8270a6f9394f88847d7f0050cf-Abstract.html)) as the policy for online game level generation
|
31 |
+
|
32 |
+
For the training arguments, please refer to the help `python train.py [option] --help`
|
analysis/generate.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
from src.gan.gankits import *
|
5 |
+
from src.utils.filesys import getpath
|
6 |
+
from src.utils.img import make_img_sheet
|
7 |
+
from src.utils.datastruct import RingQueue
|
8 |
+
from src.olgen.olg_policy import RLGenPolicy, RandGenPolicy
|
9 |
+
from src.smb.level import lvlhcat, save_batch
|
10 |
+
|
11 |
+
|
12 |
+
def rand_gen_levels(n=100, h=50, dest_path='', device=None):
    """Generate ``n`` levels by decoding randomly sampled latent vectors.

    Each level starts from a randomly chosen entry of
    ``smb/init_latvecs.npy`` followed by ``h`` latent vectors drawn with
    ``sample_latvec``; the concatenation is passed through the decoder.

    :param n: number of levels to generate.
    :param h: number of randomly sampled latent vectors per level.
    :param dest_path: when non-empty, the levels are written with
        ``save_batch`` and the stacked latent vectors with ``np.save``.
    :param device: torch device for the decoder and the latent tensors.
        ``None`` selects ``'cuda:0'`` when CUDA is available, else ``'cpu'``
        (the original code always used ``'cuda:0'``).
    :return: ``(levels, latvecs)`` where ``latvecs`` is the ``np.stack`` of
        the per-level latent arrays.
    """
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    decoder = get_decoder('models/decoder.pth', device)
    init_arxv = np.load(getpath('smb/init_latvecs.npy'))
    levels, latvecs = [], []
    for _ in range(n):
        z0 = init_arxv[random.randrange(0, len(init_arxv))]
        z0 = torch.tensor(z0, device=device, dtype=torch.float)
        # NOTE(review): torch.cat needs z0 and the sampled tensor to agree on
        # every dimension but the first -- confirm the stored vector shape.
        z = torch.cat([z0, sample_latvec(h, device)], dim=0)
        levels.append(lvlhcat(process_onehot(decoder(z))))
        latvecs.append(z.cpu().numpy())
    # Stack once and reuse for both saving and returning.
    latvecs = np.stack(latvecs)
    if dest_path:
        save_batch(levels, dest_path)
        np.save(getpath(dest_path), latvecs)
    return levels, latvecs
|
28 |
+
|
29 |
+
def generate_levels(policy, dest_folder='', batch_name='samples.lvls', n=200, h=50, parallel=64, save_img=False,
                    device=None):
    """Roll out ``policy`` to generate ``n`` levels, ``parallel`` at a time.

    Every rollout starts from a random entry of ``smb/init_latvecs.npy``;
    the policy is then stepped ``h`` times on the concatenated contents of a
    ``RingQueue`` of capacity ``policy.n``. The collected latent vectors are
    decoded into one level per rollout.

    :param policy: object exposing ``n`` and ``step(obs)``.
    :param dest_folder: when non-empty, the folder is created and the levels
        are saved there under ``batch_name``.
    :param batch_name: file name used by ``save_batch``.
    :param n: number of levels returned (surplus rollouts are discarded).
    :param h: number of policy steps per level.
    :param parallel: number of rollouts advanced together.
    :param save_img: when true, render each level to
        ``{dest_folder}/lvl-{i}.png``.
    :param device: torch device for the decoder; ``None`` selects
        ``'cuda:0'`` when CUDA is available, else ``'cpu'`` (the original
        code always used ``'cuda:0'``).
    :return: list with the first ``n`` generated levels.
    """
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    levels = []
    obs_queues = [RingQueue(policy.n) for _ in range(parallel)]
    init_arxv = np.load(getpath('smb/init_latvecs.npy'))
    decoder = get_decoder('models/decoder.pth', device)
    while len(levels) < n:
        # Reset every rollout with a random initial latent vector.
        veclists = []
        for queue in obs_queues:
            queue.clear()
            init_latvec = init_arxv[random.randrange(0, len(init_arxv))]
            queue.push(init_latvec)
            veclists.append([init_latvec])
        for _ in range(h):
            obs = np.stack([np.concatenate(queue.to_list()) for queue in obs_queues])
            actions = policy.step(obs)
            for queue, veclist, action in zip(obs_queues, veclists, actions):
                queue.push(action)
                veclist.append(action)
        # Decode each finished rollout; the latent vectors are not kept.
        for veclist in veclists:
            z = torch.tensor(np.stack(veclist), device=device).view(-1, nz, 1, 1)
            levels.append(lvlhcat(process_onehot(decoder(z))))
    levels = levels[:n]
    if dest_folder:
        os.makedirs(getpath(dest_folder), exist_ok=True)
        save_batch(levels, getpath(dest_folder, batch_name))
    if save_img:
        for i, lvl in enumerate(levels):
            lvl.to_img(f'{dest_folder}/lvl-{i}.png')
    return levels
|
61 |
+
|
62 |
+
|
63 |
+
def make_samples(path, n=12, h=20, space=12):
    """Render a one-column sample sheet for the policy stored at ``path``.

    Loads the policy with ``RLGenPolicy.from_path``, generates ``n`` levels
    of ``h`` steps each and saves the image sheet to ``{path}/samples.png``
    with ``space`` as the vertical margin.
    """
    policy = RLGenPolicy.from_path(path)
    sheet_path = f'{path}/samples.png'
    images = []
    for level in generate_levels(policy, n=n, h=h):
        images.append(level.to_img())
    make_img_sheet(images, ncols=1, y_margin=space, save_path=sheet_path)
|
69 |
+
|
70 |
+
if __name__ == '__main__':
|
71 |
+
pass
|
analysis/initial_seg.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:96086cfef10b8b7993278c96fe34916e08f3566655a5f419d41593db73d93468
|
3 |
+
size 40128
|
analysis/tests.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import csv
|
3 |
+
import time
|
4 |
+
import random
|
5 |
+
from src.smb.level import *
|
6 |
+
from src.drl.me_reg import *
|
7 |
+
from src.drl.nets import esmb_sample
|
8 |
+
from src.utils.filesys import getpath
|
9 |
+
from src.utils.datastruct import RingQueue
|
10 |
+
from src.smb.asyncsimlt import AsycSimltPool
|
11 |
+
from src.env.environments import get_padded_obs
|
12 |
+
from src.olgen.ol_generator import VecOnlineGenerator, OnlineGenerator
|
13 |
+
from src.drl.drl_uses import load_cfgs, load_performance
|
14 |
+
from src.olgen.olg_policy import process_obs, RandGenPolicy, RLGenPolicy, EnsembleGenPolicy
|
15 |
+
|
16 |
+
|
17 |
+
def evaluate_rewards(lvls, rfunc='default', dest_path='', parallel=1, eval_pool=None):
    """Evaluate the reward terms of every level in ``lvls``.

    :param lvls: iterable of levels; each is submitted as ``str(lvl)``.
    :param rfunc: reward function name, used only when a pool is created
        here.
    :param dest_path: when truthy, the result list is saved with ``np.save``
        (``None`` is now accepted and means "do not save").
    :param parallel: pool size, used only when a pool is created here.
    :param eval_pool: existing pool to reuse. When ``None`` a temporary
        ``AsycSimltPool`` is created and closed before returning; a supplied
        pool is drained with ``get(True)`` and left open.
    :return: list with one entry per evaluated level; each entry is the list
        of element-wise sums over the values of the returned mapping.
    """
    internal_pool = eval_pool is None
    if internal_pool:
        eval_pool = AsycSimltPool(parallel, rfunc_name=rfunc, verbose=False, test=True)
    res = []

    def _collect(buffer):
        # Fold each returned mapping into one list of position-wise sums.
        for _, item in buffer:
            res.append([sum(r) for r in zip(*item.values())])

    for lvl in lvls:
        eval_pool.put('evaluate', (0, str(lvl)))
        _collect(eval_pool.get())
    # Drain what is still pending: close() hands back the remaining results
    # of an internal pool, while an external pool is only waited on.
    _collect(eval_pool.close() if internal_pool else eval_pool.get(True))
    if dest_path:
        np.save(dest_path, res)
    return res
|
36 |
+
|
37 |
+
def evaluate_mnd(lvls, refs, parallel=2):
    """Evaluate ``lvls`` against reference levels with the ``'mnd_item'`` task.

    :param lvls: levels to evaluate (submitted as strings).
    :param refs: reference levels handed to the pool as strings.
    :param parallel: size of the temporary ``AsycSimltPool``.
    :return: tuple with the means of column 0 and column 1 of the collected
        results (callers in this file treat them as Hamming and
        DTW/JS-style diversity -- confirm against the pool implementation).
    """
    eval_pool = AsycSimltPool(parallel, verbose=False, refs=[str(ref) for ref in refs])
    res = []
    try:
        for lvl in lvls:
            eval_pool.put('mnd_item', str(lvl))
            res += eval_pool.get()
        res += eval_pool.get(wait=True)
    finally:
        # Always release the pool, even if an evaluation fails.
        eval_pool.close()
    res = np.array(res)
    return np.mean(res[:, 0]), np.mean(res[:, 1])
|
48 |
+
|
49 |
+
def evaluate_mpd(lvls, parallel=2):
    """Return the mean of the first ``'mpd'`` metric over all level pairs.

    All unordered pairs from ``lvls`` are dealt round-robin into
    ``parallel`` task lists, one ``'mpd'`` task is submitted per list, and
    the first component of every task result is concatenated and averaged.

    :param lvls: levels to compare pairwise (submitted as strings).
    :param parallel: number of task lists and size of the temporary pool.
    :return: ``np.mean`` of the collected values (a scalar, not a tuple).
    """
    task_datas = [[] for _ in range(parallel)]
    for i, (A, B) in enumerate(combinations(lvls, 2)):
        task_datas[i % parallel].append((str(A), str(B)))

    hms = []
    eval_pool = AsycSimltPool(parallel, verbose=False)
    try:
        for task_data in task_datas:
            eval_pool.put('mpd', task_data)
        res = eval_pool.get(wait=True)
    finally:
        # The original never closed this pool; release it as evaluate_mnd does.
        eval_pool.close()
    for task_hms, _ in res:
        hms += task_hms
    return np.mean(hms)
|
64 |
+
|
65 |
+
def evaluate_gen_log(path, parallel=5):
    """Evaluate every level batch under ``{path}/gen_log`` into a CSV file.

    For each batch yielded by ``traverse_batched_level_files`` the summed
    rewards (mean and standard deviation) and the ``evaluate_mpd`` value are
    written as one row of ``{path}/step_tests.csv``, and a progress line is
    printed.

    :param path: training folder; its ``rfunc`` config entry selects the
        reward function.
    :param parallel: pool size forwarded to the evaluation helpers.
    """
    rfunc_name = load_cfgs(path, 'rfunc')
    # NOTE(review): the header lists eight columns but each row holds five
    # values, so the mpd value lands under 'mnd-hm'. The header is left
    # unchanged because downstream readers of this CSV are not visible here.
    cols = ['step', 'r-avg', 'r-std', 'mnd-hm', 'mnd-dtw', 'mpd-hm', 'mpd-dtw', '']
    # The context manager closes the file even if an evaluation raises.
    with open(getpath(f'{path}/step_tests.csv'), 'w', newline='') as f:
        wrtr = csv.writer(f)
        wrtr.writerow(cols)
        start_time = time.time()
        for lvls, name in traverse_batched_level_files(f'{path}/gen_log'):
            step = name[4:]
            rewards = [sum(item) for item in evaluate_rewards(lvls, rfunc_name, parallel=parallel)]
            r_avg, r_std = np.mean(rewards), np.std(rewards)
            mpd = evaluate_mpd(lvls, parallel=parallel)
            line = [step, r_avg, r_std, mpd, '']
            wrtr.writerow(line)
            # Flush per row so partial results survive an interrupted run.
            f.flush()
            print(
                f'{path}: step{step} evaluated in {time.time()-start_time:.1f}s -- '
                + '; '.join(f'{k}: {v}' for k, v in zip(cols, line))
            )
|
87 |
+
|
88 |
+
def evaluate_generator(generator, nr=200, h=50, parallel=5, dest_path=None, additional_info=None, rfunc_name='default'):
    """Measure reward and diversity of ``generator`` and optionally save them.

    :param generator: object exposing ``generate(n, h)``.
    :param nr: number of levels generated for the reward test.
    :param h: second argument forwarded to ``generator.generate``.
    :param parallel: pool size forwarded to the evaluation helpers.
    :param dest_path: when truthy, the results are written there as a
        two-row CSV (header row, value row).
    :param additional_info: extra key/value pairs merged into the result.
    :param rfunc_name: reward function name.
    :return: dict with ``'r-avg'``, ``'r-std'``, ``'div'`` plus the entries
        of ``additional_info``.
    """
    if additional_info is None:
        additional_info = {}
    # Test reward
    lvls = generator.generate(nr, h)
    rewards = [sum(item) for item in evaluate_rewards(lvls, parallel=parallel, rfunc=rfunc_name)]
    r_avg, r_std = np.mean(rewards), np.std(rewards)
    # Test MPD. evaluate_mpd returns a single scalar, so it must not be
    # star-unpacked (the original `mpd, *_ = ...` raised TypeError).
    # NOTE(review): evaluate_mpd compares every pair of the 6000 levels
    # (about 18 million pairs) -- confirm this cost is intended.
    mpd = evaluate_mpd(generator.generate(3000*2, h), parallel=parallel)
    res = {
        'r-avg': r_avg, 'r-std': r_std, 'div': mpd,
    }
    res.update(additional_info)
    if dest_path:
        with open(getpath(dest_path), 'w', newline='') as f:
            keys = list(res)
            wrtr = csv.writer(f)
            wrtr.writerow(keys + [''])
            wrtr.writerow([res[k] for k in keys] + [''])
    return res
|
109 |
+
|
110 |
+
def evaluate_jmer(training_path, n=1000, max_parallel=None, device='cuda:0'):
    """Accumulate the discounted mutual-exclusion regularisation of a policy.

    Rolls the stored policy out from random initial latent vectors and, at
    every step, adds ``discount * mereg_func.forward(...)`` to a
    per-rollout total, with the discount multiplied by ``gamma`` each step.

    :param training_path: folder containing the configs and ``policy.pth``.
    :param n: number of rollouts to evaluate.
    :param max_parallel: rollouts advanced together; defaults to
        ``min(n, 512)``.
    :param device: torch device used to load and run the model.
    :return: list with one accumulated value per rollout, or ``0.`` when a
        required config key is missing (return type kept for callers).
    """
    init_vecs = np.load(getpath('smb/init_latvecs.npy'))
    try:
        # 'm' is requested only so that configs lacking it are rejected.
        _m, histlen, h, gamma, me_type = load_cfgs(training_path, 'm', 'N', 'h', 'gamma', 'me_type')
    except KeyError:
        return 0.
    mereg_func = LogWassersteinExclusion(1.) if me_type == 'logw' else WassersteinExclusion(1.)
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    model = torch.load(getpath(training_path, 'policy.pth'), map_location=device)
    model.requires_grad_(False)
    if max_parallel is None:
        max_parallel = min(n, 512)
    me_regs = []
    obs_queues = [RingQueue(histlen) for _ in range(max_parallel)]
    while len(me_regs) < n:
        size = min(max_parallel, n - len(me_regs))
        active = obs_queues[:size]
        mereg_vals, discount = np.zeros([size]), 1.
        for queue in active:
            queue.clear()
            queue.push(init_vecs[random.randrange(0, len(init_vecs))])
        for _ in range(h):
            obs = np.stack([get_padded_obs(queue.to_list(), histlen) for queue in active])
            muss, stdss, betas = model.get_intermediate(process_obs(obs, device))
            mereg_vals += discount * mereg_func.forward(muss, stdss, betas).squeeze().cpu().numpy()
            discount *= gamma
            actions, _ = esmb_sample(muss, stdss, betas)
            for queue, action in zip(active, actions.cpu().numpy()):
                queue.push(action)
        me_regs += mereg_vals.tolist()
    return me_regs
|
143 |
+
|
144 |
+
def evaluate_baseline(*rfuncs, parallel=4):
    """Evaluate the random-policy baseline and write ``baselines.csv``.

    Diversity is computed once with ``evaluate_mnd``; then, for every reward
    function name in ``rfuncs``, the mean summed reward of freshly generated
    levels is appended. Results go to ``training_data/baselines.csv`` as a
    header row and a value row.

    :param rfuncs: reward function names to evaluate.
    :param parallel: pool size forwarded to the evaluation helpers.
    """
    nr, md, nd, h = 100, 1000, 200, 50
    gen_policy = RandGenPolicy()
    olgenerator = OnlineGenerator(gen_policy)
    lvls, refs = olgenerator.generate(md, h), olgenerator.generate(nd, h)
    divs_h, divs_js = evaluate_mnd(lvls, refs, parallel=parallel)
    keys, vals = ['d-h', 'd-js'], [divs_h, divs_js]
    print(f'Diversity of baseline generator: Hamming {divs_h:.2f}; TPJS {divs_js:.2f}')
    for rfunc in rfuncs:
        print(f'Start to evaluate {rfunc}')
        start_time = time.time()
        try:
            lvls = olgenerator.generate(nr, h)
            rewards = [sum(item) for item in evaluate_rewards(lvls, parallel=parallel, rfunc=rfunc)]
        except AttributeError as e:
            # Still best-effort (presumably an unknown reward function name),
            # but report the skip instead of hiding it.
            print(f'Skip {rfunc}: {e}')
            continue
        # Append key and value together so the two CSV rows stay aligned.
        keys.append(rfunc)
        vals.append(np.mean(rewards))
        print(f'Evaluation for {rfunc} finished in {time.time()-start_time:.2f}s')
        print(f'Evaluation results for {rfunc}: {vals[-1]:.2f}')
    with open(getpath('training_data', 'baselines.csv'), 'w', newline='') as f:
        wrtr = csv.writer(f)
        wrtr.writerow(keys)
        wrtr.writerow(vals)
|
168 |
+
|
169 |
+
def sample_initial():
    """Save 500 randomly chosen rows of the initial latent vectors.

    Reads ``smb/init_latvecs.npy``, draws 500 distinct row indexes with
    ``random.sample`` and writes the selected rows to
    ``analysis/initial_seg.npy``.
    """
    source = np.load(getpath('smb/init_latvecs.npy'))
    candidates = list(range(len(source)))
    picked = random.sample(candidates, 500)
    np.save(getpath('analysis/initial_seg.npy'), source[picked, :])
|
176 |
+
|
177 |
+
def generate_levels_for_test(h=25):
    """Generate test level batches for every trained policy in the sweep.

    Each policy is loaded from ``training_data/...``, run from the fixed
    initial segments in ``analysis/initial_seg.npy`` and its levels are
    saved to the matching ``test_data/.../samples.lvls`` path.

    :param h: second argument forwarded to ``generator.generate``.
    """
    init_set = np.load(getpath('analysis/initial_seg.npy'))

    def _generate_one(policy_cls, pi_path, path):
        # Load the policy inside the try block: the original built it in the
        # caller, so a missing checkpoint escaped the FileNotFoundError
        # handler (assuming from_path raises it) and aborted the whole sweep.
        try:
            policy = policy_cls.from_path(pi_path)
            start = time.time()
            generator = VecOnlineGenerator(policy, vec_num=len(init_set))
            fd, _ = os.path.split(getpath(path))
            os.makedirs(fd, exist_ok=True)
            generator.re_init(init_set)
            lvls = generator.generate(len(init_set), h, rand_init=False)
            save_batch(lvls, path)
            print('Save to', path, '%.2fs' % (time.time() - start))
        except FileNotFoundError as e:
            print(e)

    tasks = ('fhp', 'lgp')
    for lbd, m in product(['0.0', '0.1', '0.2', '0.3', '0.4', '0.5'], [2, 3, 4, 5]):
        for i in range(1, 6):
            for task in tasks:
                sub = f'varpm-{task}/l{lbd}_m{m}/t{i}'
                _generate_one(RLGenPolicy, f'training_data/{sub}', f'test_data/{sub}/samples.lvls')
    for algo in ['sac', 'egsac', 'asyncsac', 'pmoe']:
        for i in range(1, 6):
            for task in tasks:
                sub = f'{algo}/{task}/t{i}'
                _generate_one(RLGenPolicy, f'training_data/{sub}', f'test_data/{sub}/samples.lvls')
    # The ensemble baselines use trials t1..t4 only, as in the original.
    for algo in ['sunrise', 'dvd']:
        for i in range(1, 5):
            for task in tasks:
                sub = f'{algo}/{task}/t{i}'
                _generate_one(EnsembleGenPolicy, f'training_data/{sub}', f'test_data/{sub}/samples.lvls')
|
210 |
+
|
211 |
+
|
212 |
+
if __name__ == '__main__':
|
213 |
+
generate_levels_for_test()
|
app.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import os
|
5 |
+
|
6 |
+
from src.olgen.ol_generator import VecOnlineGenerator
|
7 |
+
from src.olgen.olg_game import MarioOnlineGenGame
|
8 |
+
from src.olgen.olg_policy import RLGenPolicy
|
9 |
+
from src.smb.level import save_batch
|
10 |
+
from src.utils.filesys import getpath
|
11 |
+
from src.utils.img import make_img_sheet
|
12 |
+
|
13 |
+
import torch
|
14 |
+
|
15 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
16 |
+
|
17 |
+
def generate_and_play():
    """Generate levels with the example policy and return their images.

    Loads the policy under ``models/example_policy`` on the module-level
    ``device``, generates a batch with ``VecOnlineGenerator`` and returns
    the list produced by calling ``to_img()`` on every level. Used as the
    click handler of the Gradio demo.
    """
    model_dir = 'models/example_policy'
    N, L = 8, 10
    policy = RLGenPolicy.from_path(model_dir, device)
    generator = VecOnlineGenerator(policy, g_device=device)
    # Ensure the folder that contains the model directory exists.
    parent_dir = os.path.split(getpath(model_dir))[0]
    os.makedirs(parent_dir, exist_ok=True)

    images = []
    for level in generator.generate(N, L):
        images.append(level.to_img())
    return images
|
30 |
+
# make_img_sheet(imgs, 1, save_path=f'{path}/samples.png')
|
31 |
+
|
32 |
+
# # Play with the example policy model
|
33 |
+
# game = MarioOnlineGenGame(path)
|
34 |
+
# game.play()
|
35 |
+
|
36 |
+
|
37 |
+
# Gradio UI: a gallery for the generated level images plus a button that
# runs generate_and_play and sends its return value to the gallery.
with gr.Blocks(title="NCERL Demo") as demo:
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
    )
    btn = gr.Button("Generate levels", scale=0)

    btn.click(fn=generate_and_play, inputs=None, outputs=gallery)
|
44 |
+
|
45 |
+
if __name__ == "__main__":
|
46 |
+
demo.launch()
|
generate_and_play.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from src.olgen.ol_generator import VecOnlineGenerator
|
4 |
+
from src.olgen.olg_game import MarioOnlineGenGame
|
5 |
+
from src.olgen.olg_policy import RLGenPolicy
|
6 |
+
from src.smb.level import save_batch
|
7 |
+
from src.utils.filesys import getpath
|
8 |
+
from src.utils.img import make_img_sheet
|
9 |
+
|
10 |
+
if __name__ == '__main__':
    # Generate a batch of levels with the example policy model, then store
    # them as a level batch and as a one-column image sheet next to the model.
    path = 'models/example_policy'
    N, L = 8, 10
    generator = VecOnlineGenerator(RLGenPolicy.from_path(path))
    # Ensure the folder that contains the model directory exists.
    parent_dir = os.path.split(getpath(path))[0]
    os.makedirs(parent_dir, exist_ok=True)

    lvls = generator.generate(N, L)
    save_batch(lvls, f'{path}/samples.lvls')
    make_img_sheet([lvl.to_img() for lvl in lvls], 1, save_path=f'{path}/samples.png')

    # To play with the example policy model instead:
    # game = MarioOnlineGenGame(path)
    # game.play()
|
models/decoder.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:795903ed4957a4fc8b5a349113477643f945efe272d33a276a55671084f10051
|
3 |
+
size 1754728
|
models/example_policy/cfgs.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"N": 5, "gamma": 0.9, "h": 50, "rfunc": "lgp"}
|
models/example_policy/policy.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:95bd64a4667f1a55f73897bf1b8e9fff63d0cd2adb860ad799d180c53bc036b8
|
3 |
+
size 2430875
|
models/example_policy/samples.lvls
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
2 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
3 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
4 |
+
---------------------------------#----------------------------------------------------------------------------------------------------------------------------------------------
|
5 |
+
------------------------------oo-------------------------------o---------------------------------------------------------------o------------------------------------------------
|
6 |
+
-------------------------------------------------------------------------------------------------o------------------------------------------------------------------------------
|
7 |
+
--------------------------------------SSSSSSSSSS---------------------Q--------------------------------------------------------------QQQ-----------------------------QQQQ--------
|
8 |
+
--------------So----------------------------------------------------------------------------------------------------------------------------------------------------------------
|
9 |
+
----------------------------------------------o----------------------o---------------------------------------K----------------------------------------------------------------o-
|
10 |
+
-----------#---------------------o---------------------------------------------------------------------------2------------------------------------------------------------------
|
11 |
+
---------####--------------------------------oS------------------#---SoS-----US------------------------------U-------------------#--SSSS-----US--------tt--------##-###S-----US-
|
12 |
+
---------####----------tt-----T------------------------TT----#--TT-----------------------TT------------B---------------TT----#T-TT---------------------tt--------##-------------
|
13 |
+
--------########-------Tt-----T------------------------TT----TT-TT-----K----------------TTT------------B---------------TT----TT-TT---------------------tt----TT--##-------------
|
14 |
+
-------#########--gggg-Tt---kkT------k-----kk-----gggg-TT---kTT-T#-k-k-g--k-----k-ty----TTT--ggg---k-gog----kkk---gggg-TT---kTT-T--k-k-g--k-k-----g----tt---kkg--##k-k-g--k-k---
|
15 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
16 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
17 |
+
;
|
18 |
+
------------------------------------------------------------------------------------------------------------------------------------------S----------------------------S--S-----
|
19 |
+
----------------------------------------------------------------------S-------------------------------------------------------------------------------------------S-SSSSSS------
|
20 |
+
-----------------------------------------------------------------------------------------------------------------------------------------%%-------------------------------------
|
21 |
+
----------------------------------S------------------------------------------------------------------------------------------------------||-------------------------------------
|
22 |
+
----------------------------------------------------------------------------Koo---------------------------------------o------------------||----------------------------o-o-----o
|
23 |
+
-----------------------------------------------------------------------------------------------------------------------------------------||-------------------------------------
|
24 |
+
----------------S--Q--SSoSS--SSS--o-----------------QQoo--------------SSS----SSS%---SS-----------------------U-------SSSS-------------SSSSSSSS-------------------------SS-------
|
25 |
+
S-------------------------------------------------------------------------------|--------------------------------o--------------------------------------------------------------
|
26 |
+
SSS-So---------------------------------------S-S--------------------------------|-------------o--------------------------------------------------------------K------------------
|
27 |
+
----------------------------------------------S--------o------------------------|--------------------------------S-------------------------------------------2------------oo----
|
28 |
+
----------------Q---QS@Q----S@SSS-------------S--------2-----U-----------------S|------------US--------------U-------------------------------SS--------------U-------------%----
|
29 |
+
----------T--------------------------------------------tt---------------#-------|------B---------------TT----#-----------------------------------------B-------------------|----
|
30 |
+
---------TT--------------------------------K-----------tt--------------##--#-#--|------B---------------TT---TTT---------#-------------t----------------B-------------------|----
|
31 |
+
---------TT----#---k------k----------------b---g--gggg-tt---k--------####----#--|-gggggb--k-k-k---g----TT---kT#-------###-------------t------k-g---k-gog----kkk------------|----
|
32 |
+
---XXX-XXXXXXXXXXXXXXXXXXXXXXXXX------XXXXXXoXXXXXXXXXXXXXXXXXXXXoXXXXXXXX--oXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX%XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--%%%%-----|---@
|
33 |
+
---XXX-X-XXXXXXXXXXXXXXXXXXXXXXX------XXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX@--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX|-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---||------|----
|
34 |
+
;
|
35 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
36 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
37 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
38 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
39 |
+
------------------------------oo-----------------------------------------------o------------------------------------------------------------------------------------------------
|
40 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
41 |
+
----------------%S---So--------------QQ--------------------------------------------------------------Q-QQ--------S--QQSSQSSSSSSS-----QQQ----------------------------------------
|
42 |
+
----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------
|
43 |
+
----------------|--------------------------------------------K--------------------------------##-------------oo------------------------------oo---------------------------------
|
44 |
+
----------------|--------------------------------------------U-------------------o-----o-----###------g-------------------------------------------------------------------------
|
45 |
+
----------------|----------------#---SoS-----US--------------U-------------------------tt----###--#-####Q---S@S-#------------US--------------U-------------------------------o--
|
46 |
+
----------------|------TT-----K-TT---------------------B------K--------TT----#T--------tt---###------------------------------##--------K---------------TT----TT--------t--------
|
47 |
+
----------------|------TT-----U-TT-----K---------------B---------------TT----TT--------tt--####-----------------------------###------------------------TT----TT--------t--------
|
48 |
+
--kk-----------g|-gggg-Tt---k-U-T#-k-k----k--------k---t----k-----gggg-TT---kTT-Tg-----tt--####----k------k-k-----ggggg----####----k-kkyk---kkk--ggg-g-TT---TTT------k-tt----#--
|
49 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--XXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
50 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
51 |
+
;
|
52 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
53 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
54 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
55 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
56 |
+
------------------------------oo-----------------------------------------------------------------------------------------------o------------------------------------------------
|
57 |
+
---------------------------------o-----------------------------------------------------------------------------------------------------------------o----------------------------
|
58 |
+
----------------%----So--------------Q---------------QQQQ------------Qo--------------QQQQS-----#--------------------------------SooS--SS-----S-----SQQQ---------------o---------
|
59 |
+
----------------|------------------------------------------------------------------------------#--------------------------------------------------------------------------------
|
60 |
+
----------------|----------------------------------------------------------------------------###-------------K------------------------------------------------o-----------------
|
61 |
+
----------------|----------------------------------------------------------------------------###-------------2---------------------------------o--------------------------------
|
62 |
+
-------------oo-|----------------#--USoS-----US--#------------------------------Q-Q----QQ----###----------------------------------UQS------------------------US--------------o--
|
63 |
+
----------------|------TT-----K-TT---------------##----t---------------t---------------------###-------B---------------TT----#T----------------------------------------tt-------
|
64 |
+
---------------@|------TT-----U-TT-----K---------#---------------------t--------------------####-------B---------------TT----TT----------------------------------------tt-------
|
65 |
+
---gg----------g|-gggg-Tt---k-U-T--k-k----k------#-k--kk-----k-----k-gog----kkk---or--------####---k-gog----k-k---gggg-TT---kTT--------------------kgggg--k-k-----ggggott---kkk-
|
66 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX%%%%%-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
67 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-|XX--XXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
68 |
+
;
|
69 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
70 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
71 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
72 |
+
----------------------------------------------------------------------------------------------------SSSS------------------------------------------------------------------------
|
73 |
+
-----------------------------------------------o--------------------------------------------------------------------------------------------------------------------------------
|
74 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
75 |
+
--------------------QS--------------------------------Q-Q------------------------------------U--SSSSSSSSo----SSS---S@S@QQ-------%---SS------------------------------QQQQ--------
|
76 |
+
--------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------
|
77 |
+
-----------------------------------------------------------------------------K---------------------------------------o----------|----o--------o-------------------------------o-
|
78 |
+
-----------#-----------------------------------------------------------------2------------------g-------------------------------|------------------------------------g----------
|
79 |
+
---------TT#-------------------------------------T#--S#S-----US--------------U---------------U--S--SS---------S-----USSS-----US-|----S-2-----US------------------######S-----US-
|
80 |
+
---------TT----------------------------TT----TT-TT---------------------B---------------TT----#---------K------------------------|------K--K------------tt----#---##-------------
|
81 |
+
--------TTT----T-----------------------TT----TT------------------------B---------------TT----TT---------------------------------|------B---------------Tt----TT--##-------------
|
82 |
+
T-------TTT----T-----gg-------kg--gggg-TT---kTT-----------k--------k-gog----k-k---g----TT---kT#-------k-----k------kgggg----k---|--k-gog--k-k-----ggg--Tt----kT--##k---g--k-k---
|
83 |
+
XXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
84 |
+
X-XXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
85 |
+
;
|
86 |
+
--------------------------------------------------------------------------------------------------------------------------S-----------------------------------------------------
|
87 |
+
-----------------------------------------------------------------------------------------------------------------S-SSSSS-SS-----------------------------------------------------
|
88 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
89 |
+
------------------------------------------------------------------------------------------------SSSSSS--------------------------------------------------------------------------
|
90 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
91 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------oo----------------
|
92 |
+
------SSS---SS--------------------------------------SQo------------------------------Q--------------------------@SSSSSSSQSSSSSSS----SQo-----------------------SS----QQQQQSS--S@S
|
93 |
+
------------------------------------------------------------------------------------------------------------------------------------------------S--------------#----------------
|
94 |
+
-------------------------------------------------------------K-----------------------o-------------------------------------------------------U--S--------------#----------------
|
95 |
+
-------------------------------------------------------------------------------------------------------o---------------g---------------------------------------#----------------
|
96 |
+
---------S---@S------------------#--USSS-----US-----US-2-----U---------tt-----------USSS-----US--------tt--------SQ-SSSQQ----US--------2-----U----------------##Q-Q--QQQQS---o--
|
97 |
+
-----------------------tt----TT-TT---------------------K---------------tt-------T----------------------tt----T-------------------------K--K------------------###----------------
|
98 |
+
-----------------------Tt----TT-TT---------------------B---------------tt------------------------------tt----T-------------------------B-------------------#####----------------
|
99 |
+
----ggk-----k-----gggg-Tt---kkTTT--k-k----k--------k-gog--k-k-----g----tt---kkg----k---g--k-----------ttt---kkk----k---------------k-gog--k-k--------------#####----------------
|
100 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
101 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
102 |
+
;
|
103 |
+
-------------------------------------------------------k------------------------------------------------------------------------------------------------------------------------
|
104 |
+
-------------------------------------S------------------------------------------------------------------------------------------------------------------------------------------
|
105 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
106 |
+
-------SSSSSSSSS----------------------------------------------------------------------------------------------------------------------------------------------------------------
|
107 |
+
------------------------------oo---------------------------------------------------------------o-----------o-o----o----------o--------------------------------------------------
|
108 |
+
---------------------------------------------o-------------------------------------------------------------------------------------------------o--------------------------------
|
109 |
+
----------------%S---So----------------Q--------------SSSSSSSSSS-------------------------------------SQSoSSS-SS----S--SSSS----o-----QQ---------S--------------------------------
|
110 |
+
----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------
|
111 |
+
----------------|------------------------------------------------------------U-------------------------------------------------------------------------------K------------------
|
112 |
+
----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------
|
113 |
+
----SS----------|----------------T--S##SS----US------------SSS---------------U------------------------QQQ----USS--Qo-----#------QSQ---SSSSSSS%S--------------U---------------o--
|
114 |
+
----------------|------TT-----K-TT-----------------------------------------------------TT----TT-------------------------##-------------------|---------B---------------Tt----#--
|
115 |
+
----------------|------TT-----U-TT-------------------------------------B---------------TT----TT------------------------###---#---------------|---------B---------------TT----TT-
|
116 |
+
----------------|-gggg-Tt---k-U-TT---k----b------------------------k---b----kkk---gggg-TT---kTT----U------------------####--------or-k-------|-----k-gog----kkk---ggg--Tt----kT-
|
117 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXX--XX---------X--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXX--XXXXXXXXXX--XXXXXXXX---XXXXX%XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
118 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXX---X---------X--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXX--X-XXX-XXXX--XXXXXXXX---XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
119 |
+
;
|
120 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
121 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
122 |
+
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
123 |
+
----------S-------------------------------------------------------------------------------------------------------------------------------------SSS-----------------------------
|
124 |
+
-----------------------------U-o-----------------------------------------------o------------------------------------------------------------------------------------------------
|
125 |
+
-------------------------------------------------o---------------------------------------------o--------------------------------------------------------------------------------
|
126 |
+
------------------------------------------SS---S%--------------------------------------------------------------------QQ-------------QQoo-----U-------QQQo------------Qo---------
|
127 |
+
------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------
|
128 |
+
---------#-----T--------------------------------|------------------------------------o-------------------------------------------------------K----------------------------------
|
129 |
+
--------#------o--------------------------------|------------------------------------------------------o------------------------------------------------------------------------
|
130 |
+
-------TT------T--------------------------------|----So2-----US---------------------USSS-----@S--------tt-----------USoS-----U------U--2-----U--###--------------------------U--
|
131 |
+
-------TT----T---------TT----TT--------o--------|------K--K------------TT----TT-T----------------------tt----#---------K---------------K--------###--------------------B--------
|
132 |
+
------#TT---TT---------TT----TT--------#t-------|------B---------------TT----TT------------------------tt----T---------B---------------B--------###--------------------B--------
|
133 |
+
-----##TT---TT----gg-#-Tt---kkT---gggg-TT--#y---|-gk-gog--k-k----ggggg-TT---kTT----k------k-------g----tt---kkT-T--k-kkb--k-k------k-gob--k-k---##----k-----------gggggb----k-k-
|
134 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
135 |
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
models/example_policy/samples.png
ADDED
plots.py
ADDED
@@ -0,0 +1,733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pds
|
8 |
+
import matplotlib
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from math import sqrt
|
11 |
+
import torch
|
12 |
+
from root import PRJROOT
|
13 |
+
from sklearn.manifold import TSNE
|
14 |
+
from itertools import product, chain
|
15 |
+
# from src.drl.drl_uses import load_cfgs
|
16 |
+
from src.gan.gankits import get_decoder, process_onehot
|
17 |
+
from src.gan.gans import nz
|
18 |
+
from src.smb.level import load_batch, hamming_dis, lvlhcat
|
19 |
+
from src.utils.datastruct import RingQueue
|
20 |
+
from src.utils.filesys import load_dict_json, getpath
|
21 |
+
from src.utils.img import make_img_sheet
|
22 |
+
from torch.distributions import Normal
|
23 |
+
|
24 |
+
# Global plotting default for every figure made by this module. Per
# Matplotlib's rcParams documentation, "axes.formatter.limits" is the exponent
# range outside which tick labels switch to scientific notation.
matplotlib.rcParams["axes.formatter.limits"] = (-5, 5)
|
25 |
+
|
26 |
+
|
27 |
+
def print_compare_tab():
    """Print the LaTeX body of the RL-baseline comparison table to stdout.

    For each task ('fhp', then 'lgp') the ``reward`` and ``diversity``
    entries of every ``performance.csv`` found below
    ``test_data/<method folder>/`` are collected into one row per method.
    Four row groups are then printed: Reward, Diversity, G-mean
    (``sqrt(reward * diversity)``) and N-rank. Every group consists of a
    line of per-method means followed by a line of standard deviations.

    The function returns nothing; its only output is the printed text.
    """

    def _print_line(_data, minimise=False):
        """Print one row of means and one row of stds for ``_data``.

        ``_data`` is reduced over its last axis. The cell with the best mean
        is wrapped in the LaTeX bold macro and the cell with the worst mean
        in the italic macro; "best" is the largest mean unless ``minimise``
        is true. The std row is highlighted at the same positions.
        """
        means = _data.mean(axis=-1)
        stds = _data.std(axis=-1)
        best_i, worst_i = np.argmax(means), np.argmin(means)
        if minimise:
            best_i, worst_i = worst_i, best_i

        mean_cells = ['%.4g' % x for x in means]
        # Raw string: '\p' is not a valid escape sequence, so the non-raw
        # spelling triggers a Deprecation/SyntaxWarning on recent Pythons.
        std_cells = [r'$\pm$%.3g' % x for x in stds]

        # Bold is applied before italic so that the nesting is unchanged in
        # the degenerate case where best and worst are the same cell.
        for cells in (mean_cells, std_cells):
            cells[best_i] = r'\textbf{%s}' % cells[best_i]
            cells[worst_i] = r'\textit{%s}' % cells[worst_i]

        print(' &', ' & '.join(mean_cells), r'\\')
        print(' & &', ' & '.join(std_cells), r'\\')

    def _print_block(_task):
        """Collect the results of every compared method on ``_task`` and print them."""
        fds = [
            f'sac/{_task}', f'egsac/{_task}', f'asyncsac/{_task}',
            f'pmoe/{_task}', f'dvd/{_task}', f'sunrise/{_task}',
            f'varpm-{_task}/l0.0_m5', f'varpm-{_task}/l0.1_m5', f'varpm-{_task}/l0.2_m5',
            f'varpm-{_task}/l0.3_m5', f'varpm-{_task}/l0.4_m5', f'varpm-{_task}/l0.5_m5'
        ]
        rewards, divs = [], []
        for fd in fds:
            method_rewards, method_divs = [], []
            pattern = getpath('test_data', fd, '**', 'performance.csv')
            for path in glob.glob(pattern, recursive=True):
                reward, div = load_dict_json(path, 'reward', 'diversity')
                method_rewards.append(reward)
                method_divs.append(div)
            rewards.append(method_rewards)
            divs.append(method_divs)
        # One row per method folder, one column per performance file found.
        rewards = np.array(rewards)
        divs = np.array(divs)

        print(' & \\multirow{2}{*}{Reward}')
        _print_line(rewards)
        print(' \\cline{2-14}')
        print(' & \\multirow{2}{*}{Diversity}')
        _print_line(divs)
        print(' \\cline{2-14}')
        print(' & \\multirow{2}{*}{G-mean}')
        _print_line(np.sqrt(rewards * divs))

        print(' \\cline{2-14}')
        print(' & \\multirow{2}{*}{N-rank}')
        # Rank every single result among all results of the block
        # (1 = largest value), separately for reward and diversity.
        flat_rewards, flat_divs = rewards.flatten(), divs.flatten()
        r_rank = np.zeros_like(flat_rewards)
        r_rank[np.argsort(-flat_rewards)] = np.linspace(1, len(r_rank), len(r_rank))
        d_rank = np.zeros_like(flat_divs)
        d_rank[np.argsort(-flat_divs)] = np.linspace(1, len(d_rank), len(d_rank))

        # Average the two ranks and scale by the number of results per
        # method. Shapes come from the data instead of a hard-coded 12 x 5,
        # which gives identical numbers for that layout.
        per_method = rewards.shape[-1]
        n_rank = (r_rank.reshape(rewards.shape) + d_rank.reshape(divs.shape)) / (2 * per_method)
        _print_line(n_rank, True)

    print(' \\multirow{8}{*}{MarioPuzzle}')
    _print_block('fhp')
    print(' \\midrule')
    print(' \\multirow{8}{*}{MultiFacet}')
    _print_block('lgp')
|
98 |
+
|
99 |
+
def print_compare_tab_nonrl():
    """Print the LaTeX rows comparing the non-RL generators with NCERL.

    For each task ('fhp' is printed under MarioPuzzle, 'lgp' under
    MultiFacet) the reward and diversity stored in every
    ``performance.csv`` found below ``test_data/<folder>`` are collected and
    printed as a mean row followed by a standard-deviation row.  The best
    column is typeset in bold and the worst one in italics.  G-mean and
    N-rank rows are intentionally not part of this table.
    """

    def _print_line(_data, minimise=False):
        """Print the mean row and the std row of one criterion.

        :param _data: 2-D array; one row per table column (method), one
            entry per independent trial.
        :param minimise: when True the smallest mean is the best one.
        """
        means = _data.mean(axis=-1)
        stds = _data.std(axis=-1)
        max_i, min_i = np.argmax(means), np.argmin(means)
        mean_str_content = ['%.4g' % x for x in means]
        # Raw string: '\p' is not a valid escape sequence in a normal literal.
        std_str_content = [r'$\pm$%.3g' % x for x in stds]
        best_i, worst_i = (min_i, max_i) if minimise else (max_i, min_i)
        for cells in (mean_str_content, std_str_content):
            # Bold first, italics second, so a lone column ends up wrapped
            # by both exactly as before.
            cells[best_i] = r'\textbf{%s}' % cells[best_i]
            cells[worst_i] = r'\textit{%s}' % cells[worst_i]
        print(' &', ' & '.join(mean_str_content), r'\\')
        print(' & &', ' & '.join(std_str_content), r'\\')

    def _print_block(_task):
        """Collect the results of every method on ``_task`` and print them."""
        fds = [
            f'GAN-{_task}', f'DDPM-{_task}',
            f'varpm-{_task}/l0.0_m5', f'varpm-{_task}/l0.1_m5', f'varpm-{_task}/l0.2_m5',
            f'varpm-{_task}/l0.3_m5', f'varpm-{_task}/l0.4_m5', f'varpm-{_task}/l0.5_m5'
        ]
        rewards, divs = [], []
        for fd in fds:
            rewards.append([])
            divs.append([])
            for path in glob.glob(getpath('test_data', fd, '**', 'performance.csv'), recursive=True):
                reward, div = load_dict_json(path, 'reward', 'diversity')
                rewards[-1].append(reward)
                divs[-1].append(div)
        # NOTE: every folder must hold the same number of trials, otherwise
        # the lists are ragged and cannot form a 2-D array.
        rewards = np.array(rewards)
        divs = np.array(divs)

        print(' & \\multirow{2}{*}{Reward}')
        _print_line(rewards)
        print(' \\cline{2-10}')
        print(' & \\multirow{2}{*}{Diversity}')
        _print_line(divs)
        print(' \\cline{2-10}')

    print(' \\multirow{4}{*}{MarioPuzzle}')
    _print_block('fhp')
    print(' \\midrule')
    print(' \\multirow{4}{*}{MultiFacet}')
    _print_block('lgp')
|
169 |
+
|
170 |
+
def plot_cmp_learning_curves(task, save_path='', title=''):
    """Plot reward, diversity and G-mean learning curves of all algorithms.

    Three panels are drawn side by side (one per criterion); each curve is
    the mean over the available trials with a shaded band of one standard
    deviation.  Afterwards a second, axis-free figure containing only the
    shared legend is shown.

    :param task: task folder name, e.g. ``'fhp'`` or ``'lgp'``.
    :param save_path: project-relative path for the curve figure; when empty
        the figure is shown instead of saved.
    :param title: super-title of the curve figure.
    """
    # Borrow the six colours of the seaborn style without keeping the style.
    # 'seaborn-v0_8' is the non-deprecated name (matplotlib >= 3.6), in line
    # with the 'seaborn-v0_8-*' style used elsewhere in this file.  Reading
    # the cycle from rcParams avoids leaving a stray empty figure behind.
    plt.style.use('seaborn-v0_8')
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color'][:6]
    plt.style.use('default')

    def _get_algo_data(fd):
        """Load ``step_tests.csv`` of trials t1..t5 below ``fd``.

        Returns ``(steps, performances)`` where ``performances`` maps
        'reward' / 'diversity' / 'gmean' to ``(mean, std)`` arrays over the
        trials.  Missing files are reported by printing their path and are
        skipped; incomplete trajectories are reported as well.
        """
        res = []
        for i in range(1, 6):
            path = getpath(fd, f't{i}', 'step_tests.csv')
            try:
                data = pds.read_csv(path)
            except FileNotFoundError:
                print(path)
                continue
            trajectory = [
                [float(item['step']), float(item['r-avg']), float(item['diversity'])]
                for _, item in data.iterrows()
            ]
            trajectory.sort(key=lambda x: x[0])
            res.append(trajectory)
            if len(trajectory) != 26:
                print('Not complete (%d)/26:' % len(trajectory), path)
        res = np.array(res)
        rewards, divs = res[:, :, 1], res[:, :, 2]
        gmean = np.sqrt(rewards * divs)
        steps = res[0, :, 0]
        _performances = {
            'reward': (np.mean(rewards, axis=0), np.std(rewards, axis=0)),
            'diversity': (np.mean(divs, axis=0), np.std(divs, axis=0)),
            'gmean': (np.mean(gmean, axis=0), np.std(gmean, axis=0)),
        }
        return steps, _performances

    def _plot_criterion(_ax, _criterion):
        """Draw the curves of every algorithm for one criterion on ``_ax``."""
        # NCERL variants (lambda in the label) and the baselines each walk
        # through the colour list on their own, so both families reuse the
        # same six colours and are told apart by line style.
        i, j = 0, 0
        for algo, (steps, _performances) in performances.items():
            avgs, stds = _performances[_criterion]
            if r'\lambda' in algo:
                ls = '-'
                _c = colors[i]
                i += 1
            else:
                ls = ':' if algo in {'SAC', 'EGSAC', 'ASAC'} else '--'
                _c = colors[j]
                j += 1
            _ax.plot(steps, avgs, color=_c, label=algo, ls=ls)
            _ax.fill_between(steps, avgs - stds, avgs + stds, color=_c, alpha=0.15)
            _ax.grid(False)
        _ax.set_xlabel('Time step')

    fig, ax = plt.subplots(1, 3, figsize=(9.6, 3.2), dpi=250, width_ratios=[1, 1, 1])
    # Insertion order fixes both the colour assignment and the legend order.
    performances = {
        'SUNRISE': _get_algo_data(f'test_data/sunrise/{task}'),
        r'$\lambda$=0.0': _get_algo_data(f'test_data/varpm-{task}/l0.0_m5'),
        'DvD': _get_algo_data(f'test_data/dvd/{task}'),
        r'$\lambda$=0.1': _get_algo_data(f'test_data/varpm-{task}/l0.1_m5'),
        'PMOE': _get_algo_data(f'test_data/pmoe/{task}'),
        r'$\lambda$=0.2': _get_algo_data(f'test_data/varpm-{task}/l0.2_m5'),
        'SAC': _get_algo_data(f'test_data/sac/{task}'),
        r'$\lambda$=0.3': _get_algo_data(f'test_data/varpm-{task}/l0.3_m5'),
        'EGSAC': _get_algo_data(f'test_data/egsac/{task}'),
        r'$\lambda$=0.4': _get_algo_data(f'test_data/varpm-{task}/l0.4_m5'),
        'ASAC': _get_algo_data(f'test_data/asyncsac/{task}'),
        r'$\lambda$=0.5': _get_algo_data(f'test_data/varpm-{task}/l0.5_m5'),
    }

    _plot_criterion(ax[0], 'reward')
    _plot_criterion(ax[1], 'diversity')
    _plot_criterion(ax[2], 'gmean')
    ax[0].set_title('Cumulative Reward')
    ax[1].set_title('Diversity Score')
    ax[2].set_title('G-mean')

    # Keep the handles of the last panel for the stand-alone legend figure.
    lines, labels = fig.axes[-1].get_legend_handles_labels()
    fig.suptitle(title, fontsize=14)
    plt.tight_layout(pad=0.5)
    if save_path:
        plt.savefig(getpath(save_path))
    else:
        plt.show()

    # Second figure: nothing but the legend shared by the three panels.
    plt.cla()
    plt.figure(figsize=(9.6, 2.4), dpi=250)
    plt.grid(False)
    plt.axis('off')
    plt.yticks([1.0])
    plt.legend(
        lines, labels, loc='lower center', ncol=6, edgecolor='white', fontsize=15,
        columnspacing=0.8, borderpad=0.16, labelspacing=0.2, handlelength=2.4, handletextpad=0.3
    )
    plt.tight_layout(pad=0.5)
    plt.show()
|
293 |
+
|
294 |
+
def plot_crosstest_scatters(rfunc, xrange=None, yrange=None, title=''):
    """Scatter reward against diversity for every trained policy of a task.

    Baselines are drawn in dark grey with one marker per algorithm; NCERL
    policies are coloured by lambda and shaped by m.  The Pareto front over
    all points is overlaid as a thick translucent line.

    :param rfunc: task / reward-function folder name, e.g. ``'fhp'``.
    :param xrange: x-axis limits passed to ``plt.xlim`` (None = automatic).
    :param yrange: y-axis limits passed to ``plt.ylim`` (None = automatic).
    :param title: figure title.
    """

    def _pareto_front(points):
        """Return the non-dominated ``[reward, diversity]`` points sorted by reward."""
        front = [
            p for p in points
            if not any(
                q[0] >= p[0] and q[1] >= p[1] and (q[0] > p[0] or q[1] > p[1])
                for q in points
            )
        ]
        front.sort(key=lambda item: item[0])
        return np.array(front)

    scatter_groups = {}
    all_lbd = set()
    # Initialise
    plt.style.use('seaborn-v0_8-dark-palette')
    plt.figure(figsize=(2.5, 2.5), dpi=256)
    plt.axes().set_axisbelow(True)

    # Competitors' performances
    cmp_folders = ['asyncsac', 'egsac', 'sac', 'sunrise', 'dvd', 'pmoe']
    cmp_names = ['ASAC', 'EGSAC', 'SAC', 'SUNRISE', 'DvD', 'PMOE']
    cmp_labels = ['A', 'E', 'S', 'R', 'D', 'M']
    cmp_markers = ['2', 'x', '+', 'o', '*', 'D']
    cmp_sizes = [42, 20, 32, 16, 24, 10]
    cmp_points = []
    for name, folder, label, mk, s in zip(cmp_names, cmp_folders, cmp_labels, cmp_markers, cmp_sizes):
        path_fmt = getpath('test_data', folder, rfunc, '*', 'performance.csv')
        xs, ys = [], []
        for path in glob.glob(path_fmt, recursive=True):
            try:
                x, y = load_dict_json(path, 'reward', 'diversity')
            except FileNotFoundError:
                print(path)
                continue
            xs.append(x)
            ys.append(y)
            cmp_points.append([x, y])
        if label in {'A', 'E', 'S'}:
            # Line-like markers: filled with the dark colour.
            plt.scatter(xs, ys, marker=mk, zorder=2, s=s, label=name, color='#202020')
        else:
            # Closed markers: transparent face, dark outline.
            plt.scatter(
                xs, ys, marker=mk, zorder=2, s=s, label=name, color=[0., 0., 0., 0.],
                edgecolors='#202020', linewidths=1
            )

    # NCERL performances
    for path in glob.glob(getpath('test_data', f'varpm-{rfunc}', '**', 'performance.csv'), recursive=True):
        try:
            x, y = load_dict_json(path, 'reward', 'diversity')
            # .../l<lambda>_m<m>/t<k>/performance.csv -> 'l<lambda>_m<m>'.
            # Split on the platform separator: a hard-coded backslash only
            # worked on Windows and made every point fail elsewhere.
            key = os.path.normpath(path).split(os.sep)[-3]
            ltxt, _ = key.split('_')
            all_lbd.add(float(ltxt[1:]))
            scatter_groups.setdefault(key, []).append([x, y])
        except Exception as e:
            # Best effort: report the offending file and keep plotting.
            print(path)
            print(e)

    colors = {0.0: '#150080', 0.1: '#066598', 0.2: '#01E499', 0.3: '#9FD40C', 0.4: '#F3B020', 0.5: '#FA0000'}
    # Off-canvas dummy lines that only provide the legend entries for lambda.
    for lbd in sorted(all_lbd):
        plt.plot([-20], [-20], label=f'$\\lambda={lbd:.1f}$', lw=6, c=colors[lbd])
    markers = {2: 'o', 3: '^', 4: 'D', 5: 'p', 6: 'h'}
    msizes = {2: 25, 3: 25, 4: 16, 5: 28, 6: 32}
    for key, group in scatter_groups.items():
        ltxt, mtxt = key.split('_')
        l = float(ltxt[1:])
        m = int(mtxt[1:])
        arr = np.array(group)
        plt.scatter(
            arr[:, 0], arr[:, 1], marker=markers[m], s=msizes[m], color=[0., 0., 0., 0.], zorder=2,
            edgecolors=colors[l], linewidths=1
        )

    plt.xlim(xrange)
    plt.ylim(yrange)
    pareto = _pareto_front(list(chain(*scatter_groups.values())) + cmp_points)
    if pareto.size:  # nothing to draw when no result file was found
        plt.plot(
            pareto[:, 0], pareto[:, 1], color='black', alpha=0.18, lw=6, zorder=3,
            solid_joinstyle='round', solid_capstyle='round'
        )
    plt.title(title)
    plt.grid()
    plt.tight_layout(pad=0.4)
    plt.show()
|
403 |
+
|
404 |
+
def plot_varpm_heat(task, name):
    """Save heat maps of NCERL performance over the (m, lambda) grid.

    For each criterion (Reward, Diversity, G-mean) one figure with two heat
    maps is written to ``results/heat/<name>-<criterion>.png``: the average
    over five trials on the left and the standard deviation on the right.

    :param task: task folder name, e.g. ``'fhp'``.
    :param name: display name of the task, used in titles and file names.
    """
    ms = [2, 3, 4, 5]
    ls = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']

    def _get_score(m, l):
        """Return mean/std of reward, diversity and G-mean over trials t1..t5."""
        fd = getpath('test_data', f'varpm-{task}', f'l{l}_m{m}')
        rewards, divs = [], []
        for i in range(5):
            reward, div = load_dict_json(f'{fd}/t{i+1}/performance.csv', 'reward', 'diversity')
            rewards.append(reward)
            divs.append(div)
        gmean = [sqrt(r * d) for r, d in zip(rewards, divs)]
        return np.mean(rewards), np.std(rewards), \
            np.mean(divs), np.std(divs), \
            np.mean(gmean), np.std(gmean)

    def _cell_text(v):
        """Format a cell value compactly (about five characters).

        Large values show their integer part, values below one drop the
        leading zero.  The branch is chosen from the rounded text, so values
        such as 0.99996 or 12345.6 are no longer mangled, and negative
        values keep their sign.
        """
        s = '%.4f' % abs(v)
        whole = s.split('.')[0]
        if len(whole) >= 4:
            txt = whole
        elif whole != '0':
            txt = s[:5]
        else:
            txt = s[1:6]
        return '-' + txt if v < 0 else txt

    def _draw(ax, value_map, panel_title):
        """Draw one annotated heat map with its colour bar."""
        heat = ax.imshow(value_map, cmap='spring')
        ax.set_xlim([-0.5, len(ls) - 0.5])
        ax.set_xticks([*range(len(ls))], [r'$\lambda$=%s' % l for l in ls])
        # The increasing y-limits put row 0 at the bottom; the tick labels
        # follow the row order of the maps (row i holds ms[i]) so that each
        # row is labelled with the m it was actually computed for.
        ax.set_ylim([-0.5, len(ms) - 0.5])
        ax.set_yticks([*range(len(ms))], [f'm={m}' for m in ms])
        for x, y in product(range(len(ls)), range(len(ms))):
            ax.text(x, y, _cell_text(value_map[y, x]), va='center', ha='center')
        ax.set_title(panel_title)
        plt.colorbar(heat, ax=ax, shrink=0.9)

    def _plot_map(avg_map, std_map, criterion):
        """Render and save the average / std figure of one criterion."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3), dpi=256, width_ratios=(1, 1))
        _draw(ax1, avg_map, 'Average')
        _draw(ax2, std_map, 'Standard Deviation')
        fig.suptitle(f'{name}: {criterion}', fontsize=14)
        plt.tight_layout()
        out_path = getpath(f'results/heat/{name}-{criterion}.png')
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(out_path)

    r_mean_map, r_std_map, d_mean_map, d_std_map, g_mean_map, g_std_map \
        = (np.zeros([len(ms), len(ls)], dtype=float) for _ in range(6))
    for i, j in product(range(len(ms)), range(len(ls))):
        (
            r_mean_map[i, j], r_std_map[i, j],
            d_mean_map[i, j], d_std_map[i, j],
            g_mean_map[i, j], g_std_map[i, j],
        ) = _get_score(ms[i], ls[j])

    _plot_map(r_mean_map, r_std_map, 'Reward')
    _plot_map(d_mean_map, d_std_map, 'Diversity')
    _plot_map(g_mean_map, g_std_map, 'G-mean')
|
470 |
+
|
471 |
+
def vis_samples():
    """Render a one-column image sheet of the first ten DDPM-fhp samples per trial.

    For every trial t1..t5 the stored level batch is loaded, its first ten
    levels are converted to images and the sheet is saved next to the batch
    as ``samples.png``.  The same recipe was previously applied to the
    ``varpm-fhp/l<lambda>_m<m>`` folders and to the baseline algorithms.
    """
    for trial in range(1, 6):
        i = trial  # keeps the path expressions below identical to the data layout
        batch = load_batch(f'{PRJROOT}/test_data/DDPM-fhp/t{i}/samples.lvls')
        sheet_cells = []
        for lvl in batch[:10]:
            sheet_cells.append(lvl.to_img(save_path=None))
        make_img_sheet(sheet_cells, 1, save_path=f'{PRJROOT}/test_data/DDPM-fhp/t{i}/samples.png')
|
488 |
+
|
489 |
+
def make_tsne(task, title, n=500, save_path=None):
    """Embed generated levels of all algorithms with t-SNE and plot them.

    The pairwise Hamming distance matrix of the samples is computed once and
    cached in ``test_data/samples_dist-<task>_<n>.npy``; later calls reuse
    the cache.  Samples are stored group by group: first the six baselines,
    then the six NCERL (lambda) variants, each group holding ``n`` levels
    from each of five trials.

    :param task: task folder name, e.g. ``'fhp'``.
    :param title: figure title.
    :param n: number of levels taken from every trial.
    :param save_path: project-relative output path; the figure is shown
        instead when it is empty or None.
    """
    # (legend label, folder) in the order the samples are loaded and cached.
    baselines = [
        ('DvD', 'dvd'), ('EGSAC', 'egsac'), ('PMOE', 'pmoe'),
        ('SUNRISE', 'sunrise'), ('ASAC', 'asyncsac'), ('SAC', 'sac'),
    ]
    lbds = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']
    n_trials = 5

    dist_path = getpath('test_data', f'samples_dist-{task}_{n}.npy')
    if not os.path.exists(dist_path):
        samples = []
        for _, algo in baselines:
            for t in range(n_trials):
                lvls = load_batch(getpath('test_data', algo, task, f't{t+1}', 'samples.lvls'))
                samples += lvls[:n]
        for l in lbds:
            for t in range(n_trials):
                lvls = load_batch(getpath('test_data', f'varpm-{task}', f'l{l}_m5', f't{t+1}', 'samples.lvls'))
                samples += lvls[:n]
        distmat = np.array([[hamming_dis(a, b) for b in samples] for a in samples])
        np.save(dist_path, distmat)

    data = np.load(dist_path)
    print(data.shape)
    tsne = TSNE(learning_rate='auto', n_components=2, metric='precomputed')
    embx = np.array(tsne.fit_transform(data))

    # Non-deprecated style name, matching plot_crosstest_scatters.
    plt.style.use('seaborn-v0_8-dark-palette')
    plt.figure(figsize=(5, 5), dpi=384)
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color'][:6]

    group_size = n * n_trials  # assumes every trial provided at least n levels

    def _group(idx):
        """Return the embedded points of the idx-th loaded group."""
        return embx[idx * group_size:(idx + 1) * group_size]

    # The lambda variants were loaded AFTER the baselines, so they occupy
    # groups 6..11.  They are plotted first to keep the legend order
    # (lambda row, then baselines) while pairing each label with the samples
    # that really belong to it.
    for k, l in enumerate(lbds):
        pts = _group(len(baselines) + k)
        plt.scatter(pts[:, 0], pts[:, 1], s=10, label=r'$\lambda$=%s' % l, marker='x', c=colors[k])
    for k, (label, _) in enumerate(baselines):
        pts = _group(k)
        plt.scatter(pts[:, 0], pts[:, 1], s=8, linewidths=0, label=label, c=colors[k])

    # Extra head-room on top leaves space for the legend.
    xrange = [1.05 * embx[:, 0].min(), 1.05 * embx[:, 0].max()]
    yrange = [1.05 * embx[:, 1].min(), 1.25 * embx[:, 1].max()]

    plt.xlim(xrange)
    plt.ylim(yrange)
    plt.xticks([])
    plt.yticks([])
    plt.title(title)
    plt.legend(loc='upper center', ncol=6, fontsize=9, handlelength=.5, handletextpad=0.5, columnspacing=0.3, framealpha=0.)
    plt.tight_layout(pad=0.2)

    if save_path:
        plt.savefig(getpath(save_path))
    else:
        plt.show()
|
557 |
+
|
558 |
+
def _prob_fmt(p, digitals=3, threshold=0.001):
    """Format a probability as a compact LaTeX math string.

    :param p: probability to format.
    :param digitals: number of decimals to print.
    :param threshold: values below it are rendered as ``$\\approx 0$``.
    :return: ``'$.ddd$'`` (leading zero dropped).  When the value rounds up
        to one, the last decimal is dropped instead so that the cell keeps
        the same width, e.g. ``'$1.00$'`` for three decimals.
    """
    if p < threshold:
        return '$\\approx 0$'
    txt = ('%.' + str(digitals) + 'f') % p
    if txt.startswith('0'):
        return '$%s$' % txt[1:]
    # Rounded up to (at least) one: works for any number of decimals instead
    # of only recognising the three-decimal text '.000'.
    return '$%s$' % (txt[:-1].rstrip('.') or txt)
|
567 |
+
|
568 |
+
def _g_fmt(v, digitals=4):
    """Format ``v`` with ``%g`` and pad it with trailing zeros.

    ``%g`` strips trailing zeros; they are added back until the text holds
    ``digitals`` digits (a leading zero counts as a digit), which keeps the
    table cells equally wide: ``12.5 -> '12.50'``, ``3 -> '3.000'``.

    :param v: number to format.
    :param digitals: significant digits requested from ``%g``.
    :return: the padded text; exponent notation and non-finite values are
        returned exactly as ``%g`` produced them.
    """
    txt = ('%.' + str(digitals) + 'g') % v
    # Only plain decimal numbers can be padded; appending '.0' to something
    # like '1e-05' or 'nan' would produce invalid text.
    if not txt.lstrip('-').replace('.', '', 1).isdigit():
        return txt
    lack = digitals - len(txt.replace('-', '').replace('.', ''))
    if lack > 0 and '.' not in txt:
        txt += '.'
    return txt + '0' * lack
|
576 |
+
|
577 |
+
def print_selection_prob(path, h=15, runs=2):
    """Roll out a saved policy and print its per-step selection probabilities.

    For each run, ``h`` segments are generated from a fixed initial segment.
    At every step the mixture weights (betas) returned by
    ``model.get_intermediate`` are recorded, one component is sampled from
    them and an action is drawn from that component's Gaussian.  The
    generated level is saved as ``figures/gen_process/run<r>-01.png`` and the
    probabilities are printed as LaTeX table rows with the selected
    component of each step in bold.

    :param path: project-relative folder containing ``policy.pth``.
    :param h: number of generated segments per run.
    :param runs: number of independent roll-outs.
    """
    s0 = 0  # index of the initial segment used for every run
    model = torch.load(getpath(f'{path}/policy.pth'), map_location='cpu')
    model.requires_grad_(False)
    model.to('cpu')
    n = 11  # size of the observation ring buffer

    init_vec = np.load(getpath('analysis/initial_seg.npy'))[s0]
    decoder = get_decoder(device='cpu')
    obs_buffer = RingQueue(n)
    for r in range(runs):
        # Flush the buffer with zero vectors before the initial segment.
        # NOTE(review): h pushes are used although the buffer was created
        # with n -- presumably this relies on h >= n; confirm.
        for _ in range(h):
            obs_buffer.push(np.zeros([nz]))
        obs_buffer.push(init_vec)
        level_latvecs = [init_vec]
        probs = np.zeros([model.m, h])
        selects = []
        for t in range(h):
            obs = torch.tensor(
                np.concatenate(obs_buffer.to_list(), axis=-1), dtype=torch.float
            ).view([1, -1])
            # obs already is a tensor; wrapping it in torch.tensor() again
            # only made a copy and triggered a warning.
            muss, stdss, betas = model.get_intermediate(obs)
            i = torch.multinomial(betas.squeeze(), 1).item()
            mu, std = muss[0][i], stdss[0][i]
            action = Normal(mu, std).rsample([1]).squeeze().numpy()
            obs_buffer.push(action)
            level_latvecs.append(action)
            probs[:, t] = betas.squeeze().numpy()
            selects.append(i)
        # Stack first: building a tensor from a list of arrays is slow.
        latvecs = torch.tensor(np.array(level_latvecs), dtype=torch.float)
        onehots = decoder(latvecs.view(-1, nz, 1, 1))
        segs = process_onehot(onehots)
        lvl = lvlhcat(segs)
        lvl.to_img(f'figures/gen_process/run{r}-01.png')
        txts = [[_prob_fmt(p) for p in row] for row in probs]
        for t, i in enumerate(selects):
            # Strip the surrounding '$' and re-wrap the cell in bold math.
            txts[i][t] = r'$\boldsymbol{%s}$' % txts[i][t][1:-1]
        for i, txt in enumerate(txts):
            print(f' & $\\beta_{i+1}$ &', ' & '.join(txt), r'\\')
        print(r'\midrule')
|
627 |
+
|
628 |
+
def calc_selection_freqs(task, n):
    """Estimate how often each sub-policy is selected and cache the result.

    Every saved NCERL policy of ``task`` (lambda in 0.0..0.5, m in 2..5,
    trials t1..t5) is rolled out from all initial segments at once.  The
    selection frequencies of its components, sorted in descending order, are
    appended to the row of its (lambda, trial) pair, giving 30 rows that
    list the frequencies for m = 2, 3, 4, 5 one after another.  The rows are
    written to ``analysis/select_freqs-<task>.json``.

    :param task: task folder name, e.g. ``'fhp'``.
    :param n: size of the observation ring buffer.
    """
    # Use the GPU when there is one, otherwise fall back to the CPU instead
    # of failing on the hard-coded device.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    init_vecs = np.load(getpath('analysis/initial_seg.npy'))
    n_init = len(init_vecs)
    batch_rows = [*range(n_init)]
    runs, h = 100, 25

    def _count_selections(model):
        """Return how many times each component of ``model`` was selected."""
        counts = np.zeros([model.m])
        obs_buffer = RingQueue(n)
        for _ in range(runs):
            # Flush the buffer with zeros, then start from the initial segments.
            for _ in range(h):
                obs_buffer.push(np.zeros([n_init, nz]))
            obs_buffer.push(init_vecs)
            for _ in range(h):
                obs = np.concatenate(obs_buffer.to_list(), axis=-1)
                obs = torch.tensor(obs, device=device, dtype=torch.float)
                muss, stdss, betas = model.get_intermediate(obs)
                selects = torch.multinomial(betas.squeeze(), 1).squeeze()
                # Pick, for every batch row, the Gaussian of its selected component.
                mus = muss[batch_rows, selects, :]
                stds = stdss[batch_rows, selects, :]
                actions = Normal(mus, stds).rsample().squeeze().cpu().numpy()
                obs_buffer.push(actions)
                # One vectorised count instead of a Python loop over a
                # (possibly GPU-resident) tensor.
                counts += np.bincount(selects.reshape(-1).cpu().numpy(), minlength=len(counts))
        return counts

    freqs = [[] for _ in range(30)]
    start_line = 0
    for l in ('0.0', '0.1', '0.2', '0.3', '0.4', '0.5'):
        print(r' \midrule')
        for t, m in product(range(1, 6), (2, 3, 4, 5)):
            path = getpath(f'test_data/varpm-{task}/l{l}_m{m}/t{t}')
            model = torch.load(getpath(f'{path}/policy.pth'), map_location=device)
            model.requires_grad_(False)
            freq = np.zeros([m])
            freq += _count_selections(model)
            freq /= (n_init * runs * h)
            freq = np.sort(freq)[::-1]
            i = start_line + t - 1
            freqs[i] += freq.tolist()
            print(freqs[i])
        start_line += 5
    print(freqs)
    with open(getpath(f'analysis/select_freqs-{task}.json'), 'w') as f:
        json.dump(freqs, f)
|
674 |
+
|
675 |
+
def print_selection_freq():
    """Print the cached sub-policy selection frequencies as LaTeX table rows.

    The frequencies of the 'fhp' task are computed first when no cache file
    exists.  Rows come in blocks of five trials per lambda; every block is
    preceded by a rule and a ``\\multirow`` cell holding its lambda.
    """
    # task, n = 'lgp', 5
    task, n = 'fhp', 11
    cache_file = getpath(f'analysis/select_freqs-{task}.json')
    if not os.path.exists(cache_file):
        calc_selection_freqs(task, n)
    with open(cache_file, 'r') as f:
        freqs = json.load(f)
    lbds = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']
    for row_idx, row_data in enumerate(freqs):
        block_idx, trial_idx = divmod(row_idx, 5)
        if trial_idx == 0:
            print(r' \midrule')
            print(r' \multirow{5}{*}{$%s$}' % lbds[block_idx])
        cells = ' & '.join(_prob_fmt(p) for p in row_data)
        print(f' & {trial_idx+1} &', cells, r'\\')
|
689 |
+
|
690 |
+
def print_individual_performances(task):
    """Print reward and diversity of every single NCERL trial as LaTeX rows.

    One row is printed per (m, lambda) pair; it lists the reward/diversity
    pairs of the five trials ordered by decreasing reward.  A rule and a
    ``\\multirow`` cell holding m open each group of six lambda rows.

    :param task: task folder name, e.g. ``'fhp'``.
    """
    lbds = ('0.0', '0.1', '0.2', '0.3', '0.4', '0.5')
    for m in (2, 3, 4, 5):
        for l in lbds:
            if l == '0.0':
                print(r' \midrule')
                print(r' \multirow{6}{*}{%d}' % m)
            trials = []
            for t in range(1, 6):
                path = f'test_data/varpm-{task}/l{l}_m{m}/t{t}/performance.csv'
                reward, diversity = load_dict_json(path, 'reward', 'diversity')
                trials.append([reward, diversity])
            # Best reward first.
            trials.sort(key=lambda item: -item[0])
            cells = [_g_fmt(v) for pair in trials for v in pair]
            print(' &', f'${l}$ & ', ' & '.join(cells), r'\\')
|
705 |
+
|
706 |
+
if __name__ == '__main__':
    # Script entry point: exactly one task is active at a time; uncomment
    # the call of the table or figure that should be regenerated.

    # --- LaTeX tables ---
    # print_selection_prob('test_data/varpm-fhp/l0.5_m5/t5')
    # print_selection_prob('test_data/varpm-fhp/l0.1_m5/t5')
    # print_selection_freq()
    # print_compare_tab_nonrl()
    # print_individual_performances('fhp')
    # print('\n\n')
    # print_individual_performances('lgp')

    # --- Learning curves ---
    # plot_cmp_learning_curves('fhp', save_path='results/learning_curves/fhp.png', title='MarioPuzzle')
    # plot_cmp_learning_curves('lgp', save_path='results/learning_curves/lgp.png', title='MultiFacet')

    # --- Reward/diversity scatter plots ---
    # plot_crosstest_scatters('fhp', title='MarioPuzzle')
    # plot_crosstest_scatters('lgp', title='MultiFacet')
    # # plot_crosstest_scatters('fhp', yrange=(0, 2500), xrange=(20, 70), title='MarioPuzzle')
    # plot_crosstest_scatters('lgp', yrange=(0, 1500), xrange=(20, 50), title='MultiFacet')
    # plot_crosstest_scatters('lgp', yrange=(0, 800), xrange=(44, 48), title=' ')

    # --- Heat maps ---
    # plot_varpm_heat('fhp', 'MarioPuzzle')
    # plot_varpm_heat('lgp', 'MultiFacet')

    # --- Sample sheets (currently active) ---
    vis_samples()

    # --- t-SNE embeddings ---
    # make_tsne('fhp', 'MarioPuzzle', n=100)
    # make_tsne('lgp', 'MultiFacet', n=100)
|
733 |
+
|
pyproject.toml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "ncerl"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = ""
|
5 |
+
authors = ["Ziqi Wang"]
|
6 |
+
readme = "README.md"
|
7 |
+
|
8 |
+
[tool.poetry.dependencies]
|
9 |
+
python = "^3.9"
|
10 |
+
JPype1 = "1.3.0"
|
11 |
+
dtw = "1.4.0"
|
12 |
+
torch = "1.8.1"
|
13 |
+
numpy = "^2.0.0"
|
14 |
+
pillow = "10.0.0"
|
15 |
+
matplotlib = "3.6.3"
|
16 |
+
pandas = "1.3.2"
|
17 |
+
|
18 |
+
|
19 |
+
[build-system]
|
20 |
+
requires = ["poetry-core"]
|
21 |
+
build-backend = "poetry.core.masonry.api"
|
requirements.txt
ADDED
Binary file (4.87 kB). View file
|
|
root.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
PRJROOT = os.path.dirname(os.path.realpath(__file__))
|
smb/Mario-AI-Framework.jar
ADDED
Binary file (206 kB). View file
|
|
smb/assets/#.png
ADDED
smb/assets/1.png
ADDED
smb/assets/2.png
ADDED
smb/assets/@.png
ADDED
smb/assets/B.png
ADDED
smb/assets/BSP.png
ADDED
smb/assets/CB1.png
ADDED
smb/assets/CB2.png
ADDED
smb/assets/L.png
ADDED
smb/assets/ML.png
ADDED
smb/assets/MM.png
ADDED
smb/assets/MR.png
ADDED
smb/assets/MS.png
ADDED
smb/assets/Q.png
ADDED
smb/assets/S.png
ADDED
smb/assets/TLP.png
ADDED
smb/assets/TRP.png
ADDED
smb/assets/TSP.png
ADDED
smb/assets/U.png
ADDED
smb/assets/X.png
ADDED
smb/assets/[.png
ADDED
smb/assets/].png
ADDED
smb/assets/chomper.png
ADDED
smb/assets/g.png
ADDED
smb/assets/k.png
ADDED
smb/assets/o.png
ADDED
smb/assets/r.png
ADDED
smb/assets/stalk.png
ADDED
smb/assets/wingk.png
ADDED
smb/assets/wingr.png
ADDED
smb/assets/y.png
ADDED
smb/img/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
This folder contains the game graphics. All the graphics files have been modified by hand from the original files that were grabbed from: https://www.spriters-resource.com/nes/supermariobros/. The exception is the font.gif file, which is the same as the file from the MarioAI framework.
|
smb/img/background.png
ADDED