baiyanlali-zhao committed
Commit 3582c8a
Parent(s): 8be1cb6
Add comments

Files changed:
- README.md +58 -16
- analysis/tests.py +0 -3
- app.py +1 -7
- generate_and_play.py +8 -5
- media/banner.png +0 -0
- models/example_policy/samples.lvls +122 -122
- models/example_policy/samples.png +0 -0
- plots.py +0 -733
- pyproject.toml +0 -21
- requirements.txt +0 -0
- src/drl/egsac/train_egsac.py +2 -2
- src/drl/sunrise/train_sunrise.py +3 -8
- src/drl/train_async.py +15 -0
- src/drl/train_sinproc.py +13 -0
- src/env/environments.py +2 -5
- src/env/rfunc.py +6 -5
- src/gan/adversarial_train.py +0 -21
- src/gan/gankits.py +2 -2
- src/gan/gans.py +1 -1
- src/olgen/olg_policy.py +7 -44
- src/smb/asyncsimlt.py +4 -9
- src/smb/proxy.py +0 -8
- src/utils/img.py +0 -1
- test_ddpm.py +1 -81
- test_gen_log.py +0 -15
- test_gen_samples.py +0 -24
- tests.py +0 -140
- train.py +2 -0
README.md
CHANGED
@@ -8,26 +8,39 @@ python_version: 3.9
 app_file: app.py
 pinned: false
 ---
-
-* torch 1.8.2+cu111
-* numpy 1.20.3
-* gym 0.21.0
-* scipy 1.7.2
-* Pillow 10.0.0
-* matplotlib 3.6.3
-* pandas 1.3.2
-* sklearn 1.0.1
-
+![NCERL banner](./media/banner.png)
+# Negatively Correlated Ensemble RL
+
+## Environment Setup
+Create a conda environment:
+```bash
+conda create -n ncerl python=3.9
+```
+Install the dependencies:
+```bash
+pip install -r requirements.txt
+```
+Note: this program does not require a GPU, but PyTorch must be installed. If your GPU supports CUDA, install the CUDA build of PyTorch; otherwise install the CPU build. The CUDA build speeds up inference.
+
+Activate the conda environment:
+```bash
+conda activate ncerl
+```
+
+## Quick Start
+To try the generator interactively, run
+```bash
+python app.py
+```
+and open the link printed in the terminal.
+
+Alternatively, run
+```bash
+python generate_and_play.py
+```
+and inspect the generated levels in `models/example_policy/samples.png`.
+
+## Training
 
 All training runs are launched via `train.py` with an option and arguments. For example, executing `python train.py ncesac --lbd 0.3 --m 5` trains NCERL with hyperparameters $\lambda = 0.3, m = 5$.
 The plot script is `plots.py`.
@@ -36,9 +49,38 @@
 * `python train.py sac`: to train a standard SAC as the policy for online game level generation
 * `python train.py asyncsac`: to train a SAC with an asynchronous evaluation environment as the policy for online game level generation
 * `python train.py ncesac`: to train an NCERL based on SAC as the policy for online game level generation
-* `python train.py egsac`: to train an episodic generative SAC (see paper [*The fun facets of Mario: Multifaceted experience-driven PCG via reinforcement learning*](https://dl.acm.org/doi/abs/10.1145/3555858.3563282
+* `python train.py egsac`: to train an episodic generative SAC (see paper [*The fun facets of Mario: Multifaceted experience-driven PCG via reinforcement learning*](https://dl.acm.org/doi/abs/10.1145/3555858.3563282)) as the policy for online game level generation
 * `python train.py pmoe`: to train a Probabilistic Mixture-of-Experts policy (see paper [*Probabilistic Mixture-of-Experts for Efficient Deep Reinforcement Learning*](https://arxiv.org/abs/2104.09122)) as the policy for online game level generation
 * `python train.py sunrise`: to train a SUNRISE (see paper [*SUNRISE: A Simple Unified Framework for Ensemble Learning in Deep Reinforcement Learning*](https://proceedings.mlr.press/v139/lee21g.html)) as the policy for online game level generation
 * `python train.py dvd`: to train a DvD-SAC (see paper [*Effective Diversity in Population Based Reinforcement Learning*](https://proceedings.neurips.cc/paper_files/paper/2020/hash/d1dc3a8270a6f9394f88847d7f0050cf-Abstract.html)) as the policy for online game level generation
 
 For the training arguments, please refer to the help: `python train.py [option] --help`
+
+## Directory Structure
+```
+NCERL-DIVERSE-PCG/
+* analysis/
+* generate.py          unused
+* tests.py             used for evaluation
+* media/               assets for this README
+* models/
+* example_policy/      used for the generation demo
+* smb/                 Mario simulation and image resource data
+* src/
+* ddpm/                DDPM model code
+* drl/                 DRL models and training code
+* env/                 Mario gym environment and reward functions
+* gan/                 GAN models and training code
+* olgen/               online level generation environment and policies
+* rlkit/               reinforcement learning building blocks
+* smb/                 simulator interface and multiprocess async pool components
+* utils/               utility modules
+* training_data/       training data
+* README.md            this file
+* app.py               Gradio demo entry point
+* generate_and_play.py non-Gradio demo entry point
+* train.py             training entry point
+* test_ddpm.py         DDPM training/testing script
+* requirements.txt     dependency list
+```
+
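The README shows the `train.py` interface only through examples (`python train.py ncesac --lbd 0.3 --m 5` and `python train.py [option] --help`). As a hedged illustration of how such a subcommand interface is typically dispatched, here is a minimal argparse sketch; it is not the repo's actual `train.py` (whose body is not shown in this diff), and every parser name and default beyond the README's examples is an assumption:

```python
# Hypothetical sketch of a train.py-style CLI; only the `ncesac` option and its
# --lbd / --m flags come from the README, everything else is assumed.
import argparse

parser = argparse.ArgumentParser(prog='train.py')
subparsers = parser.add_subparsers(dest='option', required=True)

ncesac = subparsers.add_parser('ncesac', help='train NCERL based on SAC')
ncesac.add_argument('--lbd', type=float, default=0.0,
                    help='lambda: weight of the diversity (negative-correlation) term')
ncesac.add_argument('--m', type=int, default=5, help='number of ensemble members')

# `python train.py ncesac --lbd 0.3 --m 5` would parse to:
args = parser.parse_args(['ncesac', '--lbd', '0.3', '--m', '5'])
print(args.option, args.lbd, args.m)  # ncesac 0.3 5
```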
analysis/tests.py
CHANGED
@@ -36,7 +36,6 @@ def evaluate_rewards(lvls, rfunc='default', dest_path='', parallel=1, eval_pool=
 
 def evaluate_mnd(lvls, refs, parallel=2):
     eval_pool = AsycSimltPool(parallel, verbose=False, refs=[str(ref) for ref in refs])
-    # m, _ = len(lvls), len(refs)
     res = []
     for lvl in lvls:
         eval_pool.put('mnd_item', str(lvl))
@@ -49,7 +48,6 @@ def evaluate_mnd(lvls, refs, parallel=2):
 
 def evaluate_mpd(lvls, parallel=2):
     task_datas = [[] for _ in range(parallel)]
     for i, (A, B) in enumerate(combinations(lvls, 2)):
-        # lvlA, lvlB = lvls[i * 2], lvls[i * 2 + 1]
         task_datas[i % parallel].append((str(A), str(B)))
 
     hms, dtws = [], []
@@ -73,7 +71,6 @@ def evaluate_gen_log(path, parallel=5):
 
     step = name[4:]
     rewards = [sum(item) for item in evaluate_rewards(lvls, rfunc_name, parallel=parallel)]
     r_avg, r_std = np.mean(rewards), np.std(rewards)
-    # mpd_hm, mpd_dtw = evaluate_mpd(lvls, parallel=parallel)
     mpd = evaluate_mpd(lvls, parallel=parallel)
     line = [step, r_avg, r_std, mpd, '']
     wrtr.writerow(line)
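The `evaluate_mpd` hunk above relies on an idiom worth spelling out: every unordered pair of levels from `itertools.combinations` is dealt round-robin into `parallel` task buckets via `i % parallel`. A minimal self-contained sketch, with plain strings standing in for the repo's level objects:

```python
from itertools import combinations

lvls = ['lvl-a', 'lvl-b', 'lvl-c', 'lvl-d']  # hypothetical stand-ins for levels
parallel = 2
task_datas = [[] for _ in range(parallel)]
for i, (A, B) in enumerate(combinations(lvls, 2)):
    # deal the C(4, 2) = 6 pairs out to the worker buckets in turn
    task_datas[i % parallel].append((str(A), str(B)))

print([len(t) for t in task_datas])  # [3, 3]: an even split across workers
```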
app.py
CHANGED
@@ -9,7 +9,6 @@ sys.path.append(path.dirname(path.abspath(__file__)))
 
 
 from src.olgen.ol_generator import VecOnlineGenerator
-# from src.olgen.olg_game import MarioOnlineGenGame
 from src.olgen.olg_policy import RLGenPolicy
 from src.smb.level import save_batch
 from src.utils.filesys import getpath
@@ -21,7 +20,7 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 
 def generate_and_play():
     path = 'models/example_policy'
-    #
+    # Generate with the example policy
     N, L = 8, 10
     plc = RLGenPolicy.from_path(path, device)
     generator = VecOnlineGenerator(plc, g_device=device)
@@ -29,14 +28,9 @@ def generate_and_play():
     os.makedirs(fd, exist_ok=True)
 
     lvls = generator.generate(N, L)
-    # save_batch(lvls, f'{path}/samples.lvls')
     imgs = [lvl.to_img() for lvl in lvls]
     return imgs
-    # make_img_sheet(imgs, 1, save_path=f'{path}/samples.png')
 
-    # # Play with the example policy model
-    # game = MarioOnlineGenGame(path)
-    # game.play()
 
 
 with gr.Blocks(title="NCERL Demo") as demo:
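The diff shows only the first line of the Gradio app (`with gr.Blocks(title="NCERL Demo") as demo:`) plus the fact that `generate_and_play()` returns a list of PIL images. A minimal sketch of how such a function is typically wired into Blocks; the Gallery and Button layout is an assumption, not the actual app.py body:

```python
import gradio as gr

def generate_levels():
    # stand-in for app.py's generate_and_play(), which returns a list of
    # PIL images of generated levels; an empty list keeps this sketch runnable
    return []

with gr.Blocks(title="NCERL Demo") as demo:
    gallery = gr.Gallery(label="Generated levels")
    btn = gr.Button("Generate")
    btn.click(fn=generate_levels, outputs=gallery)

if __name__ == '__main__':
    demo.launch()
```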
generate_and_play.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-
+import torch
 from src.olgen.ol_generator import VecOnlineGenerator
 from src.olgen.olg_game import MarioOnlineGenGame
 from src.olgen.olg_policy import RLGenPolicy
@@ -7,12 +7,14 @@ from src.smb.level import save_batch
 from src.utils.filesys import getpath
 from src.utils.img import make_img_sheet
 
+device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+
 if __name__ == '__main__':
     path = 'models/example_policy'
     # Generate with example policy model
     N, L = 8, 10
-    plc = RLGenPolicy.from_path(path)
-    generator = VecOnlineGenerator(plc)
+    plc = RLGenPolicy.from_path(path, device=device)
+    generator = VecOnlineGenerator(plc, g_device=device)
     fd, _ = os.path.split(getpath(path))
     os.makedirs(fd, exist_ok=True)
 
@@ -22,6 +24,7 @@ if __name__ == '__main__':
     make_img_sheet(imgs, 1, save_path=f'{path}/samples.png')
 
     # # Play with the example policy model
-    #
-
+    # Make sure a JVM is installed and that typing `java` in a terminal prints Java's info
+    game = MarioOnlineGenGame(path)
+    game.play()
     pass
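The new comment asks the user to ensure a JVM is installed before `MarioOnlineGenGame` launches the Java-based simulator. A small optional guard (an assumption, not part of the committed code) that fails fast when `java` is missing:

```python
import shutil

# hypothetical pre-flight check: MarioOnlineGenGame needs the Java simulator,
# so verify a `java` executable is reachable on PATH before starting the game
if shutil.which('java') is None:
    raise RuntimeError('`java` not found on PATH; install a JVM first.')
```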
media/banner.png
ADDED
models/example_policy/samples.lvls
CHANGED
@@ -1,135 +1,135 @@
[All 135 rows of ASCII tile data in this file were regenerated by this commit: each old row of the eight sample levels (strings of `-`, `X`, `S`, `Q`, pipe, coin and enemy tiles, with `;` rows separating individual levels) is replaced by a row from freshly generated levels. The raw tile rows are omitted here.]
models/example_policy/samples.png
CHANGED
plots.py
DELETED
@@ -1,733 +0,0 @@
[The entire 733-line plotting script is deleted. Its contents, shown in the original diff down to line 413 before the view truncates, were:
- imports: glob, json, os, re, numpy, pandas, matplotlib, torch, sklearn's TSNE, itertools, torch.distributions.Normal, plus project modules (src.gan.gankits, src.gan.gans, src.smb.level, src.utils.datastruct, src.utils.filesys, src.utils.img);
- print_compare_tab(): prints a LaTeX table comparing reward, diversity, G-mean and normalised rank across SAC, EGSAC, AsyncSAC, PMOE, DvD, SUNRISE and the NCERL variants (lambda from 0.0 to 0.5, m = 5) on the fhp (MarioPuzzle) and lgp (MultiFacet) tasks, reading each run's test_data/**/performance.csv;
- print_compare_tab_nonrl(): the same table comparing the non-RL baselines GAN-{task} and DDPM-{task} against the NCERL variants;
- plot_cmp_learning_curves(task, save_path='', title=''): plots mean and std learning curves of reward, diversity and G-mean over time steps, read from each run's step_tests.csv;
- plot_crosstest_scatters(rfunc, xrange=None, yrange=None, title=''): reward-versus-diversity scatter plot of all algorithms with a highlighted Pareto front;
- plot_varpm_heat(task, name): heat maps of scores over lambda and m (the listing truncates at this point).]
-
return np.mean(rewards), np.std(rewards), \
|
414 |
-
np.mean(divs), np.std(divs), \
|
415 |
-
np.mean(gmean), np.std(gmean)
|
416 |
-
|
417 |
-
def _plot_map(avg_map, std_map, criterion):
|
418 |
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3), dpi=256, width_ratios=(1, 1))
|
419 |
-
heat1 = ax1.imshow(avg_map, cmap='spring')
|
420 |
-
heat2 = ax2.imshow(std_map, cmap='spring')
|
421 |
-
ax1.set_xlim([-0.5, 5.5])
|
422 |
-
ax1.set_xticks([0, 1, 2, 3, 4, 5], ['$\lambda$=0.0', '$\lambda$=0.1', '$\lambda$=0.2', '$\lambda$=0.3', '$\lambda$=0.4', '$\lambda$=0.5'])
|
423 |
-
ax1.set_ylim([-0.5, 3.5])
|
424 |
-
ax1.set_yticks([0, 1, 2, 3], ['m=5', 'm=4', 'm=3', 'm=2'])
|
425 |
-
ax1.set_title('Average')
|
426 |
-
for x, y in product([0, 1, 2, 3, 4, 5], [0, 1, 2, 3]):
|
427 |
-
v = avg_map[y, x]
|
428 |
-
s = '%.4f' % v
|
429 |
-
if v >= 1000: s = s[:4]
|
430 |
-
elif v >= 1: s = s[:5]
|
431 |
-
else: s = s[1:6]
|
432 |
-
ax1.text(x, y, s, va='center', ha='center')
|
433 |
-
plt.colorbar(heat1, ax=ax1, shrink=0.9)
|
434 |
-
ax2.set_xlim([-0.5, 5.5])
|
435 |
-
ax2.set_xticks([0, 1, 2, 3, 4, 5], ['$\lambda$=0.0', '$\lambda$=0.1', '$\lambda$=0.2', '$\lambda$=0.3', '$\lambda$=0.4', '$\lambda$=0.5'])
|
436 |
-
ax2.set_ylim([-0.5, 3.5])
|
437 |
-
ax2.set_yticks([0, 1, 2, 3], ['m=5', 'm=4', 'm=3', 'm=2'])
|
438 |
-
for x, y in product([0, 1, 2, 3, 4, 5], [0, 1, 2, 3]):
|
439 |
-
v = std_map[y, x]
|
440 |
-
s = '%.4f' % v
|
441 |
-
if v >= 1000: s = s[:4]
|
442 |
-
elif v >= 1: s = s[:5]
|
443 |
-
else: s = s[1:6]
|
444 |
-
ax2.text(x, y, s, va='center', ha='center')
|
445 |
-
ax2.set_title('Standard Deviation')
|
446 |
-
plt.colorbar(heat2, ax=ax2, shrink=0.9)
|
447 |
-
|
448 |
-
fig.suptitle(f'{name}: {criterion}', fontsize=14)
|
449 |
-
plt.tight_layout()
|
450 |
-
# plt.show()
|
451 |
-
plt.savefig(getpath(f'results/heat/{name}-{criterion}.png'))
|
452 |
-
|
453 |
-
r_mean_map, r_std_map, d_mean_map, d_std_map, g_mean_map, g_std_map \
|
454 |
-
= (np.zeros([4, 6], dtype=float) for _ in range(6))
|
455 |
-
ms = [2, 3, 4, 5]
|
456 |
-
ls = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']
|
457 |
-
for i, j in product(range(4), range(6)):
|
458 |
-
r_mean, r_std, d_mean, d_std, g_mean, g_std = _get_score(ms[i], ls[j])
|
459 |
-
r_mean_map[i, j] = r_mean
|
460 |
-
r_std_map[i, j] = r_std
|
461 |
-
d_mean_map[i, j] = d_mean
|
462 |
-
d_std_map[i, j] = d_std
|
463 |
-
g_mean_map[i, j] = g_mean
|
464 |
-
g_std_map[i, j] = g_std
|
465 |
-
|
466 |
-
_plot_map(r_mean_map, r_std_map, 'Reward')
|
467 |
-
_plot_map(d_mean_map, d_std_map, 'Diversity')
|
468 |
-
_plot_map(g_mean_map, g_std_map,'G-mean')
|
469 |
-
# _plot_map(g_mean_map, g_std_map,'G-mean')
|
470 |
-
|
471 |
-
def vis_samples():
|
472 |
-
# for l, m in product(['0.0', '0.1', '0.2', '0.3', '0.4', '0.5'], [2, 3, 4, 5]):
|
473 |
-
# for i in range(1, 6):
|
474 |
-
# lvls = load_batch(f'{PRJROOT}/test_data/varpm-fhp/l{l}_m{m}/t{i}/samples.lvls')
|
475 |
-
# imgs = [lvl.to_img(save_path=None) for lvl in lvls[:10]]
|
476 |
-
# make_img_sheet(imgs, 1, save_path=f'{PRJROOT}/test_data/varpm-fhp/l{l}_m{m}/t{i}/samples.png')
|
477 |
-
# for algo in ['sac', 'egsac', 'asyncsac', 'dvd', 'sunrise', 'pmoe']:
|
478 |
-
# for i in range(1, 6):
|
479 |
-
# lvls = load_batch(f'{PRJROOT}/test_data/{algo}/fhp/t{i}/samples.lvls')
|
480 |
-
# imgs = [lvl.to_img(save_path=None) for lvl in lvls[:10]]
|
481 |
-
# make_img_sheet(imgs, 1, save_path=f'{PRJROOT}/test_data/{algo}/fhp/t{i}/samples.png')
|
482 |
-
for i in range(1, 6):
|
483 |
-
lvls = load_batch(f'{PRJROOT}/test_data/DDPM-fhp/t{i}/samples.lvls')
|
484 |
-
imgs = [lvl.to_img(save_path=None) for lvl in lvls[:10]]
|
485 |
-
make_img_sheet(imgs, 1, save_path=f'{PRJROOT}/test_data/DDPM-fhp/t{i}/samples.png')
|
486 |
-
pass
|
487 |
-
pass
|
488 |
-
|
489 |
-
def make_tsne(task, title, n=500, save_path=None):
|
490 |
-
if not os.path.exists(getpath('test_data', f'samples_dist-{task}_{n}.npy')):
|
491 |
-
samples = []
|
492 |
-
for algo in ['dvd', 'egsac', 'pmoe', 'sunrise', 'asyncsac', 'sac']:
|
493 |
-
for t in range(5):
|
494 |
-
lvls = load_batch(getpath('test_data', algo, task, f't{t+1}', 'samples.lvls'))
|
495 |
-
samples += lvls[:n]
|
496 |
-
for l in ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']:
|
497 |
-
for t in range(5):
|
498 |
-
lvls = load_batch(getpath('test_data', f'varpm-{task}', f'l{l}_m5', f't{t+1}', 'samples.lvls'))
|
499 |
-
samples += lvls[:n]
|
500 |
-
distmat = []
|
501 |
-
for a in samples:
|
502 |
-
dist_list = []
|
503 |
-
for b in samples:
|
504 |
-
dist_list.append(hamming_dis(a, b))
|
505 |
-
distmat.append(dist_list)
|
506 |
-
distmat = np.array(distmat)
|
507 |
-
np.save(getpath('test_data', f'samples_dist-{task}_{n}.npy'), distmat)
|
508 |
-
|
509 |
-
labels = (
|
510 |
-
'$\lambda$=0.0', '$\lambda$=0.1', '$\lambda$=0.2', '$\lambda$=0.3', '$\lambda$=0.4',
|
511 |
-
'$\lambda$=0.5', 'DvD', 'EGSAC', 'PMOE', 'SUNRISE', 'ASAC', 'SAC'
|
512 |
-
)
|
513 |
-
tsne = TSNE(learning_rate='auto', n_components=2, metric='precomputed')
|
514 |
-
print(np.load(getpath('test_data', f'samples_dist-{task}_{n}.npy')).shape)
|
515 |
-
data = np.load(getpath('test_data', f'samples_dist-{task}_{n}.npy'))
|
516 |
-
embx = np.array(tsne.fit_transform(data))
|
517 |
-
|
518 |
-
plt.style.use('seaborn-dark-palette')
|
519 |
-
plt.figure(figsize=(5, 5), dpi=384)
|
520 |
-
colors = [plt.plot([-1000, -1100], [0, 0])[0].get_color() for _ in range(6)]
|
521 |
-
for i in range(6):
|
522 |
-
x, y = embx[i*n*5:(i+1)*n*5, 0], embx[i*n*5:(i+1)*n*5, 1]
|
523 |
-
plt.scatter(x, y, s=10, label=labels[i], marker='x', c=colors[i])
|
524 |
-
for i in range(6, 12):
|
525 |
-
x, y = embx[i*n*5:(i+1)*n*5, 0], embx[i*n*5:(i+1)*n*5, 1]
|
526 |
-
plt.scatter(x, y, s=8, linewidths=0, label=labels[i], c=colors[i-6])
|
527 |
-
# plt.scatter(embx[100:200, 0], embx[100:200, 1], c=colors[1], s=12, linewidths=0, label='Killer')
|
528 |
-
# plt.scatter(embx[200:, 0], embx[200:, 1], c=colors[2], s=12, linewidths=0, label='Collector')
|
529 |
-
# for i in range(4):
|
530 |
-
# plt.text(embx[i+100, 0], embx[i+100, 1], str(i+1))
|
531 |
-
# plt.text(embx[i+200, 0], embx[i+200, 1], str(i+1))
|
532 |
-
# pass
|
533 |
-
# for emb, lb, c in zip(embs, labels,colors):
|
534 |
-
# plt.scatter(emb[:,0], emb[:,1], c=c, label=lb, alpha=0.15, linewidths=0, s=7)
|
535 |
-
|
536 |
-
# xspan = 1.08 * max(abs(embx[:, 0].max()), abs(embx[:, 0].min()))
|
537 |
-
# yspan = 1.08 * max(abs(embx[:, 1].max()), abs(embx[:, 1].min()))
|
538 |
-
|
539 |
-
xrange = [1.05 * embx[:, 0].min(), 1.05 * embx[:, 0].max()]
|
540 |
-
yrange = [1.05 * embx[:, 1].min(), 1.25 * embx[:, 1].max()]
|
541 |
-
|
542 |
-
plt.xlim(xrange)
|
543 |
-
plt.ylim(yrange)
|
544 |
-
plt.xticks([])
|
545 |
-
plt.yticks([])
|
546 |
-
# plt.legend(ncol=6, handletextpad=0.02, labelspacing=0.05, columnspacing=0.16)
|
547 |
-
# plt.xticks([-xspan, -0.5 * xspan, 0, 0.5 * xspan, xspan], [''] * 5)
|
548 |
-
# plt.yticks([-yspan, -0.5 * yspan, 0, 0.6 * yspan, yspan], [''] * 5)
|
549 |
-
plt.title(title)
|
550 |
-
plt.legend(loc='upper center', ncol=6, fontsize=9, handlelength=.5, handletextpad=0.5, columnspacing=0.3, framealpha=0.)
|
551 |
-
plt.tight_layout(pad=0.2)
|
552 |
-
|
553 |
-
if save_path:
|
554 |
-
plt.savefig(getpath(save_path))
|
555 |
-
else:
|
556 |
-
plt.show()
|
557 |
-
|
558 |
-
def _prob_fmt(p, digitals=3, threshold=0.001):
|
559 |
-
fmt = '%.' + str(digitals) + 'f'
|
560 |
-
if p < threshold:
|
561 |
-
return '$\\approx 0$'
|
562 |
-
else:
|
563 |
-
txt = '$%s$' % ((fmt % p)[1:])
|
564 |
-
if txt == '$.000$':
|
565 |
-
txt = '$1.00$'
|
566 |
-
return txt
|
567 |
-
|
568 |
-
def _g_fmt(v, digitals=4):
|
569 |
-
fmt = '%.' + str(digitals) + 'g'
|
570 |
-
txt = (fmt % v)
|
571 |
-
lack = digitals - len(txt.replace('-', '').replace('.', ''))
|
572 |
-
if lack > 0 and '.' not in txt:
|
573 |
-
txt += '.'
|
574 |
-
return txt + '0' * lack
|
575 |
-
pass
|
576 |
-
|
577 |
-
def print_selection_prob(path, h=15, runs=2):
|
578 |
-
s0 = 0
|
579 |
-
model = torch.load(getpath(f'{path}/policy.pth'), map_location='cpu')
|
580 |
-
model.requires_grad_(False)
|
581 |
-
model.to('cpu')
|
582 |
-
n = 11
|
583 |
-
# n = load_cfgs(path, 'N')
|
584 |
-
# print(model.m)
|
585 |
-
|
586 |
-
init_vec = np.load(getpath('analysis/initial_seg.npy'))[s0]
|
587 |
-
decoder = get_decoder(device='cpu')
|
588 |
-
obs_buffer = RingQueue(n)
|
589 |
-
for r in range(runs):
|
590 |
-
for _ in range(h): obs_buffer.push(np.zeros([nz]))
|
591 |
-
obs_buffer.push(init_vec)
|
592 |
-
level_latvecs = [init_vec]
|
593 |
-
probs = np.zeros([model.m, h])
|
594 |
-
# probs = []
|
595 |
-
selects = []
|
596 |
-
for t in range(h):
|
597 |
-
# probs.append([])
|
598 |
-
obs = torch.tensor(np.concatenate(obs_buffer.to_list(), axis=-1), dtype=torch.float).view([1, -1])
|
599 |
-
muss, stdss, betas = model.get_intermediate(torch.tensor(obs))
|
600 |
-
i = torch.multinomial(betas.squeeze(), 1).item()
|
601 |
-
# print(i)
|
602 |
-
mu, std = muss[0][i], stdss[0][i]
|
603 |
-
action = Normal(mu, std).rsample([1]).squeeze().numpy()
|
604 |
-
# print(action)
|
605 |
-
# print(mu)
|
606 |
-
# print(std)
|
607 |
-
# print(action.numpy())
|
608 |
-
obs_buffer.push(action)
|
609 |
-
level_latvecs.append(action)
|
610 |
-
# i = torch.multinomial(betas.squeeze(), 1).item()
|
611 |
-
# print(i)
|
612 |
-
probs[:, t] = betas.squeeze().numpy()
|
613 |
-
selects.append(i)
|
614 |
-
pass
|
615 |
-
onehots = decoder(torch.tensor(level_latvecs).view(-1, nz, 1, 1))
|
616 |
-
segs = process_onehot(onehots)
|
617 |
-
lvl = lvlhcat(segs)
|
618 |
-
lvl.to_img(f'figures/gen_process/run{r}-01.png')
|
619 |
-
txts = [[_prob_fmt(p) for p in row] for row in probs]
|
620 |
-
for t, i in enumerate(selects):
|
621 |
-
txts[i][t] = r'$\boldsymbol{%s}$' % txts[i][t][1:-1]
|
622 |
-
for i, txt in enumerate(txts):
|
623 |
-
print(f' & $\\beta_{i+1}$ &', ' & '.join(txt), r'\\')
|
624 |
-
print(r'\midrule')
|
625 |
-
|
626 |
-
pass
|
627 |
-
|
628 |
-
def calc_selection_freqs(task, n):
|
629 |
-
def _count_one_init():
|
630 |
-
counts = np.zeros([model.m])
|
631 |
-
# init_vec = np.load(getpath('analysis/initial_seg.npy'))
|
632 |
-
obs_buffer = RingQueue(n)
|
633 |
-
for _ in range(runs):
|
634 |
-
for _ in range(h): obs_buffer.push(np.zeros([len(init_vecs), nz]))
|
635 |
-
obs_buffer.push(init_vecs)
|
636 |
-
# level_latvecs = [init_vec]
|
637 |
-
for _ in range(h):
|
638 |
-
obs = np.concatenate(obs_buffer.to_list(), axis=-1)
|
639 |
-
obs = torch.tensor(obs, device='cuda:0', dtype=torch.float)
|
640 |
-
muss, stdss, betas = model.get_intermediate(obs)
|
641 |
-
selects = torch.multinomial(betas.squeeze(), 1).squeeze()
|
642 |
-
mus = muss[[*range(len(init_vecs))], selects, :]
|
643 |
-
stds = stdss[[*range(len(init_vecs))], selects, :]
|
644 |
-
actions = Normal(mus, stds).rsample().squeeze().cpu().numpy()
|
645 |
-
obs_buffer.push(actions)
|
646 |
-
for i in selects:
|
647 |
-
counts[i] = counts[i] + 1
|
648 |
-
return counts
|
649 |
-
# onehots = decoder(torch.tensor(level_latvecs).view(-1, nz, 1, 1))
|
650 |
-
pass
|
651 |
-
pass
|
652 |
-
init_vecs = np.load(getpath('analysis/initial_seg.npy'))
|
653 |
-
freqs = [[] for _ in range(30)]
|
654 |
-
start_line = 0
|
655 |
-
for l in ('0.0', '0.1', '0.2', '0.3', '0.4', '0.5'):
|
656 |
-
print(r' \midrule')
|
657 |
-
for t, m in product(range(1, 6), (2, 3, 4, 5)):
|
658 |
-
path = getpath(f'test_data/varpm-{task}/l{l}_m{m}/t{t}')
|
659 |
-
model = torch.load(getpath(f'{path}/policy.pth'), map_location='cuda:0')
|
660 |
-
model.requires_grad_(False)
|
661 |
-
freq = np.zeros([m])
|
662 |
-
# n = load_cfgs(path, 'N')
|
663 |
-
runs, h = 100, 25
|
664 |
-
freq += _count_one_init()
|
665 |
-
freq /= (len(init_vecs) * runs * h)
|
666 |
-
freq = np.sort(freq)[::-1]
|
667 |
-
i = start_line + t - 1
|
668 |
-
freqs[i] += freq.tolist()
|
669 |
-
print(freqs[i])
|
670 |
-
start_line += 5
|
671 |
-
print(freqs)
|
672 |
-
with open(getpath(f'analysis/select_freqs-{task}.json'), 'w') as f:
|
673 |
-
json.dump(freqs, f)
|
674 |
-
|
675 |
-
def print_selection_freq():
|
676 |
-
# task, n = 'lgp', 5
|
677 |
-
task, n = 'fhp', 11
|
678 |
-
if not os.path.exists(getpath(f'analysis/select_freqs-{task}.json')):
|
679 |
-
calc_selection_freqs(task, n)
|
680 |
-
with open(getpath(f'analysis/select_freqs-{task}.json'), 'r') as f:
|
681 |
-
freqs = json.load(f)
|
682 |
-
lbds = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']
|
683 |
-
for i, row_data in enumerate(freqs):
|
684 |
-
if i % 5 == 0:
|
685 |
-
print(r' \midrule')
|
686 |
-
print(r' \multirow{5}{*}{$%s$}' % lbds[i//5])
|
687 |
-
txt = ' & '.join(map(_prob_fmt, row_data))
|
688 |
-
print(f' & {i%5+1} &', txt, r'\\')
|
689 |
-
|
690 |
-
def print_individual_performances(task):
|
691 |
-
for m, l in product((2, 3, 4, 5), ('0.0', '0.1', '0.2', '0.3', '0.4', '0.5')):
|
692 |
-
values = []
|
693 |
-
if l == '0.0':
|
694 |
-
print(r' \midrule')
|
695 |
-
print(r' \multirow{6}{*}{%d}' % m)
|
696 |
-
for t in range(1, 6):
|
697 |
-
path = f'test_data/varpm-{task}/l{l}_m{m}/t{t}/performance.csv'
|
698 |
-
reward, diversity = load_dict_json(path, 'reward', 'diversity')
|
699 |
-
values.append([reward, diversity])
|
700 |
-
values.sort(key=lambda item: -item[0])
|
701 |
-
values = [*chain(*values)]
|
702 |
-
txts = [_g_fmt(v) for v in values]
|
703 |
-
print(' &', f'${l}$ & ', ' & '.join(txts), r'\\')
|
704 |
-
pass
|
705 |
-
|
706 |
-
if __name__ == '__main__':
|
707 |
-
# print_selection_prob('test_data/varpm-fhp/l0.5_m5/t5')
|
708 |
-
# print_selection_prob('test_data/varpm-fhp/l0.1_m5/t5')
|
709 |
-
# print_selection_freq()
|
710 |
-
# print_compare_tab_nonrl()
|
711 |
-
# print_individual_performances('fhp')
|
712 |
-
# print('\n\n')
|
713 |
-
# print_individual_performances('lgp')
|
714 |
-
|
715 |
-
# plot_cmp_learning_curves('fhp', save_path='results/learning_curves/fhp.png', title='MarioPuzzle')
|
716 |
-
# plot_cmp_learning_curves('lgp', save_path='results/learning_curves/lgp.png', title='MultiFacet')
|
717 |
-
|
718 |
-
# plot_crosstest_scatters('fhp', title='MarioPuzzle')
|
719 |
-
# plot_crosstest_scatters('lgp', title='MultiFacet')
|
720 |
-
# # plot_crosstest_scatters('fhp', yrange=(0, 2500), xrange=(20, 70), title='MarioPuzzle')
|
721 |
-
# plot_crosstest_scatters('lgp', yrange=(0, 1500), xrange=(20, 50), title='MultiFacet')
|
722 |
-
# plot_crosstest_scatters('lgp', yrange=(0, 800), xrange=(44, 48), title=' ')
|
723 |
-
|
724 |
-
|
725 |
-
# plot_varpm_heat('fhp', 'MarioPuzzle')
|
726 |
-
# plot_varpm_heat('lgp', 'MultiFacet')
|
727 |
-
|
728 |
-
vis_samples()
|
729 |
-
|
730 |
-
# make_tsne('fhp', 'MarioPuzzle', n=100)
|
731 |
-
# make_tsne('lgp', 'MultiFacet', n=100)
|
732 |
-
pass
|
733 |
-
|
pyproject.toml
DELETED
@@ -1,21 +0,0 @@
-[tool.poetry]
-name = "ncerl"
-version = "0.1.0"
-description = ""
-authors = ["Ziqi Wang"]
-readme = "README.md"
-
-[tool.poetry.dependencies]
-python = "^3.9"
-JPype1 = "1.3.0"
-dtw = "1.4.0"
-torch = "1.8.1"
-numpy = "^2.0.0"
-pillow = "10.0.0"
-matplotlib = "3.6.3"
-pandas = "1.3.2"
-
-
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
src/drl/egsac/train_egsac.py
CHANGED
@@ -71,6 +71,7 @@ def train_EGSAC(args):
         return
     device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'
 
+    # dynamically import the reward function
    rfunc = importlib.import_module('src.env.rfuncs').__getattribute__(f'{args.rfunc_name}')()
    with open(res_path + '/run_config.txt', 'w') as f:
        f.write(time.strftime('%Y-%m-%d %H:%M') + '\n')
@@ -83,8 +84,7 @@ def train_EGSAC(args):
        f.write('-' * 50 + '\n')
        f.write(str(rfunc))
    hist_len = rfunc.get_n()
-
-    # json.dump(hist_len, f)
+
    with open(f'{res_path}/cfgs.json', 'w') as f:
        data = {'N': hist_len, 'gamma': args.gamma, 'h': args.eplen, 'rfunc': args.rfunc_name}
        json.dump(data, f)
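The comment added above documents the scripts' runtime lookup of a reward function by name. A minimal sketch of that `importlib` pattern; `load_attr` is a hypothetical helper (not a repo function), demonstrated on a stdlib module so it runs anywhere:

```python
import importlib

def load_attr(module_name: str, attr_name: str):
    # resolve module_name.attr_name at runtime; getattr on the module
    # object is equivalent to the scripts' __getattribute__ call
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)

sqrt = load_attr('math', 'sqrt')   # stands in for load_attr('src.env.rfuncs', args.rfunc_name)
print(sqrt(9.0))                   # 3.0
```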
src/drl/sunrise/train_sunrise.py
CHANGED
@@ -141,7 +141,6 @@ def get_trainer(args, obs_dim, action_dim, path, device):
 
 
     trainer = NeurIPS20SACEnsembleTrainer(
-        # env=eval_env,
         policy=L_policy,
         qf1=L_qf1,
         qf2=L_qf2,
@@ -159,7 +158,6 @@ def get_trainer(args, obs_dim, action_dim, path, device):
         **variant['trainer_kwargs']
     )
     return trainer
-    pass
 
 def get_algo(args, rfunc, device, path):
     algorithm = AsyncOffPolicyALgo(
@@ -178,6 +176,7 @@ def get_algo(args, rfunc, device, path):
     return algorithm
 
 def train_SUNRISE(args):
+    # create the output directory
     if not args.path:
         path = auto_dire('training_data', args.name)
     else:
@@ -192,6 +191,8 @@ def train_SUNRISE(args):
         print(f'Trainning at <{path}> is skipped as there has a finished trial already.')
         return
     device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'
+
+    # import the reward function
     rfunc = importlib.import_module('src.env.rfuncs').__getattribute__(f'{args.rfunc}')()
 
     with open(path + '/run_configuration.txt', 'w') as f:
@@ -217,9 +218,3 @@ def train_SUNRISE(args):
     trainer = get_trainer(args, obs_dim, action_dim, path, device)
     algorithm = get_algo(args, rfunc, device, path)
     _, timecost = record_time(algorithm.train)(env, trainer, args.budget, args.inference_type, path)
-
-    pass
-
-
-if __name__ == '__main__':
-    pass
src/drl/train_async.py
CHANGED
@@ -36,6 +36,9 @@ def set_common_args(parser):
     )
 
 def drl_train(foo):
+    """
+    DRL training wrapper; foo is the wrapped builder function, e.g. train_AsyncSAC.
+    """
     def __inner(args):
         if not args.path:
             path = auto_dire('training_data', args.name)
@@ -74,6 +77,7 @@ def drl_train(foo):
             json.dump(data, f)
         obs_dim, act_dim = env.histlen * nz, nz
 
+        # foo builds and returns the agent; the returned type is ActCrtAgent
         agent = foo(args, path, device, obs_dim, act_dim)
 
         agent.to(device)
@@ -89,6 +93,7 @@ def set_AsyncSAC_parser(parser):
     set_common_args(parser)
     parser.add_argument('--name', type=str, default='AsyncSAC', help='Name of this algorithm.')
 
+# the same SAC training, but with asynchronous evaluation
 @drl_train
 def train_AsyncSAC(args, path, device, obs_dim, act_dim):
     actor = SoftActor(
@@ -116,12 +121,16 @@ def set_NCESAC_parser(parser):
 @drl_train
 def train_NCESAC(args, path, device, obs_dim, act_dim):
     me_reg, actor_nn_constructor = None, None
+
+    # initialise the selected regulariser
     if args.me_type == 'log':
         me_reg = LogWassersteinExclusion(args.lbd)
     elif args.me_type == 'clip':
         me_reg = ClipExclusion(args.lbd)
     elif args.me_type == 'logclip':
         me_reg = LogClipExclusion(args.lbd)
+
+    # initialise the selected network constructor
     if args.actor_net_type == 'conv':
         actor_nn_constructor = lambda: EsmbGaussianConv(
             obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m
@@ -130,7 +139,11 @@ def train_NCESAC(args, path, device, obs_dim, act_dim):
         actor_nn_constructor = lambda: EsmbGaussianMLP(
             obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m
         )
+
+    # initialise the actor
     actor = MERegMixSoftActor(actor_nn_constructor, me_reg, tar_ent=args.tar_entropy)
+
+    # initialise the critic
     critic = MERegSoftDoubleClipCriticQ(
         lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
         gamma=args.gamma, tau=args.tau
@@ -139,6 +152,8 @@ def train_NCESAC(args, path, device, obs_dim, act_dim):
         lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
         gamma=args.gamma, tau=args.tau
     )
+
+    # save the neural network architecture
     with open(f'{path}/nn_architecture.txt', 'w') as f:
         f.writelines([
             '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
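The new docstring on `drl_train` above describes a decorator that factors the shared setup (result directory, config dumps, environment, trainer) out of per-algorithm builders such as `train_AsyncSAC` and `train_NCESAC`. A stripped-down, runnable sketch of the same pattern, with placeholder setup and hypothetical names throughout:

```python
def drl_train(build_agent):
    # shared setup lives in the wrapper; the decorated function only
    # builds the algorithm-specific agent from the prepared context
    def __inner(args):
        path, device = 'training_data/demo', 'cpu'   # stands in for dirs/configs
        obs_dim, act_dim = 100, 20
        agent = build_agent(args, path, device, obs_dim, act_dim)
        print(f'training {agent} at {path} on {device}')
    return __inner

@drl_train
def train_demo(args, path, device, obs_dim, act_dim):
    return f'DemoAgent(obs={obs_dim}, act={act_dim})'

train_demo(args=None)
```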
src/drl/train_sinproc.py
CHANGED
@@ -30,24 +30,30 @@ def set_common_args(parser):
 
 def drl_train(foo):
     def __inner(args):
+        # set the save path
         if not args.path:
            path = auto_dire('training_data', args.name)
         else:
            path = getpath('training_data', args.path)
            os.makedirs(path, exist_ok=True)
+        # skip if a finished model already exists
         if os.path.exists(f'{path}/policy.pth'):
             print(f'Trainning at <{path}> is skipped as there has a finished trial already.')
             return
         device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'
 
+        # import the reward function
         rfunc = importlib.import_module('src.env.rfuncs').__getattribute__(f'{args.rfunc}')()
         env = SingleProcessOLGenEnv(rfunc, get_decoder('models/decoder.pth'), args.eplen, device=device)
+
+        # set up the loggers
         loggers = [
             AsyncCsvLogger(f'{path}/log.csv', rfunc),
             AsyncStdLogger(rfunc, 2000, f'{path}/log.txt' if args.redirect else '')
         ]
         if args.periodic_gen_num > 0:
             loggers.append(GenResLogger(path, args.periodic_gen_num, args.gen_period))
+        # save the run configuration
         with open(path + '/run_configuration.txt', 'w') as f:
             f.write(time.strftime('%Y-%m-%d %H:%M') + '\n')
             f.write(f'---------{args.name}---------\n')
@@ -59,14 +65,18 @@ def drl_train(foo):
             f.write('-' * 50 + '\n')
             f.write(str(rfunc))
         N = rfunc.get_n()
+        # save the config file
         with open(f'{path}/cfgs.json', 'w') as f:
             data = {'N': N, 'gamma': args.gamma, 'h': args.eplen, 'rfunc': args.rfunc}
             json.dump(data, f)
+        # set the observation and action dimensions
         obs_dim, act_dim = env.hist_len * nz, nz
 
+        # create the agent
         agent = foo(args, path, device, obs_dim, act_dim)
 
         agent.to(device)
+        # create the trainer
         trainer = SinProcOffpolicyTrainer(
             ReplayMem(args.mem_size, device=device), update_per=args.update_per, batch=args.batch
         )
@@ -76,10 +86,12 @@ def drl_train(foo):
     return __inner
 
 ############### SAC ###############
+# parser setup for SAC
 def set_SAC_parser(parser):
     set_common_args(parser)
     parser.add_argument('--name', type=str, default='SAC', help='Name of this algorithm.')
 
+# SAC training function
 @drl_train
 def train_SAC(args, path, device, obs_dim, act_dim):
     actor = SoftActor(
@@ -88,6 +100,7 @@ def train_SAC(args, path, device, obs_dim, act_dim):
     critic = SoftDoubleClipCriticQ(
         lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens), gamma=args.gamma, tau=args.tau
     )
+    # save the neural network architecture
     with open(f'{path}/nn_architecture.txt', 'w') as f:
         f.writelines([
             '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
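The `cfgs.json` written in the hunk above is read back later when a trained policy is reloaded (e.g. `load_cfgs(path, 'N')` in `src/olgen/olg_policy.py`). A minimal sketch of that round trip, with illustrative values:

```python
import json, os, tempfile

cfg = {'N': 5, 'gamma': 0.9, 'h': 50, 'rfunc': 'default'}   # illustrative values
path = os.path.join(tempfile.mkdtemp(), 'cfgs.json')
with open(path, 'w') as f:
    json.dump(cfg, f)                 # written once at training time
with open(path) as f:
    assert json.load(f)['N'] == 5     # read back at generation time
```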
src/env/environments.py
CHANGED
@@ -43,7 +43,7 @@ class SingleProcessOLGenEnv(gym.Env):
        self.device = device
        self.action_space = gym.spaces.Box(-1, 1, (nz,))
        self.observation_space = gym.spaces.Box(-1, 1, (self.hist_len * nz,))
-
+
        self.lat_vecs = []
        self.simulator = MarioProxy()
        pass
@@ -62,7 +62,6 @@ class SingleProcessOLGenEnv(gym.Env):
 
     def __evalute(self):
         z = torch.tensor(np.stack(self.lat_vecs).reshape([-1, nz, 1, 1]), device=self.device, dtype=torch.float)
-        # print(z.shape)
         segs = process_onehot(self.decoder(z))
         lvl = lvlhcat(segs)
         simlt_res = MarioProxy.get_seg_infos(self.simulator.simulate_complete(lvl))
@@ -123,6 +122,7 @@ class AsyncOlGenEnv:
         self.decoder = decoder
         self.decoder.to(device)
         self.device = device
+        # the Mario simulator lives inside eval_pool
         self.eval_pool = eval_pool
         self.eplen = eplen
         self.tid = 0
@@ -250,7 +250,6 @@ class SyncOLGenWorkerEnv(gym.Env):
         done = self.counter >= self.eplen
         if done:
             full_level = lvlhcat(self.segs)
-            # full_level = self.repairer.repair(full_level)
             w = MarioLevel.seg_width
             segs = [full_level[:, s: s + w] for s in range(0, full_level.w, w)]
             if self.mario_proxy:
@@ -352,14 +351,12 @@ class VecOLGenEnv(SubprocVecEnv):
         target_remotes = self._get_target_remotes(env_ids)
 
         n_inits = 1 if self.init_one else self.hist_len
-        # latvecs = [sample_latvec(n_inits, tensor=False) for _ in range(len(env_ids))]
 
         latvecs = [self.latvec_set[random.sample(range(len(self.latvec_set)), n_inits)] for _ in range(len(env_ids))]
         with torch.no_grad():
             segss = [[] for _ in range(len(env_ids))]
             for i in range(len(env_ids)):
                 z = torch.tensor(latvecs[i]).view(-1, nz, 1, 1).to(self.device)
-                # print(self.decoder(z).shape)
                 segss[i] = [process_onehot(self.decoder(z))] if self.init_one else process_onehot(self.decoder(z))
             for remote, latvec, segs in zip(target_remotes, latvecs, segss):
                 kwargs = {'backup_latvecs': latvec, 'backup_strsegs': [str(seg) for seg in segs]}
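All three environments touched above follow the gym 0.21 reset/step protocol with latent vectors as actions. A runnable stand-in that shows only that interaction loop; the class and its dimensions are hypothetical placeholders, not the repo's environments:

```python
import gym
import numpy as np

class DummyLatentEnv(gym.Env):
    # mimics the observation/action spaces of SingleProcessOLGenEnv
    def __init__(self, nz=20, hist_len=5, eplen=10):
        self.action_space = gym.spaces.Box(-1, 1, (nz,))
        self.observation_space = gym.spaces.Box(-1, 1, (hist_len * nz,))
        self.eplen, self.t = eplen, 0

    def reset(self):
        self.t = 0
        return np.zeros(self.observation_space.shape, dtype=np.float32)

    def step(self, action):
        self.t += 1
        obs = np.zeros(self.observation_space.shape, dtype=np.float32)
        return obs, 0.0, self.t >= self.eplen, {}

env = DummyLatentEnv()
obs, done = env.reset(), False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
```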
src/env/rfunc.py
CHANGED
@@ -44,6 +44,9 @@ class RewardTerm:
 
 
 class Playability(RewardTerm):
+    """
+    Playability reward term
+    """
    def __init__(self, magnitude=1):
        super(Playability, self).__init__(True)
        self.magnitude=magnitude
@@ -57,6 +60,9 @@ class Playability(RewardTerm):
 
 
 class MeanDivergenceFun(RewardTerm):
+    """
+    Diversity reward term
+    """
     def __init__(self, goal_div, n=defaults['n'], s=8):
         super().__init__(False)
         self.l = goal_div * 0.26 / 0.6
@@ -74,7 +80,6 @@ class MeanDivergenceFun(RewardTerm):
         divergences = []
         while k * self.s <= (min(self.n, i) - 1) * MarioLevel.seg_width:
             cmp_seg = histroy[:, k * self.s: k * self.s + MarioLevel.seg_width]
-            # print(i, nd, cmp_seg.shape)
             divergences.append(tile_pattern_js_div(seg, cmp_seg))
             k += 1
         mean_d = sum(divergences) / len(divergences)
@@ -211,9 +216,5 @@ class HistoricalDeviation(RewardTerm):
 
 
 if __name__ == '__main__':
-    # print(type(ceil(0.2)))
-    # arr = [1., 3., 2.]
-    # arr.sort()
-    # print(arr)
     rfunc = HistoricalDeviation()
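The two docstrings added above mark `Playability` and `MeanDivergenceFun` as concrete reward terms. As a hedged sketch of the pattern (the `score` method and the `seg_info` dict are illustrative, not the repo's actual interface), a composite reward simply sums per-criterion terms:

```python
class RewardTerm:
    def score(self, seg_info) -> float:
        raise NotImplementedError

class PlayabilityTerm(RewardTerm):
    # penalise a segment the simulated agent cannot traverse
    def __init__(self, magnitude=1.0):
        self.magnitude = magnitude
    def score(self, seg_info):
        return 0.0 if seg_info['playable'] else -self.magnitude

class CompositeReward:
    def __init__(self, *terms):
        self.terms = terms
    def score(self, seg_info):
        return sum(t.score(seg_info) for t in self.terms)

print(CompositeReward(PlayabilityTerm()).score({'playable': False}))  # -1.0
```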
src/gan/adversarial_train.py
CHANGED
@@ -87,7 +87,6 @@ def train_GAN(args):
         w = csv.writer(f)
         w.writerow(['key', 'value', ''])
         w.writerows(list(cfgs.items()))
-        # pds.DataFrame.from_dict(cfgs, orient='index', columns=['value']).to_csv(f'{path_}/cfgs.csv')
 
     start_time = time.time()
     log_target = open(f'{res_path}/logs.csv', 'w')
@@ -120,26 +119,6 @@ def train_GAN(args):
             loss_G = -netD(fake).mean()
             loss_G.backward()
             optG.step()
-            # # Evaluate
-            # if t % args.eval_itv == (args.eval_itv - 1):
-            #     netG.eval()
-            #     netD.eval()
-            #     with torch.no_grad():
-            #         real = torch.stack(data[:min(100, len(data))])
-            #         z = sample_latvec(100, device=device, distribuion=args.noise)
-            #         fake = netG(z)
-            #         y_real = netD(real).mean().item()
-            #         y_fake = netD(fake).mean().item()
-            #         # hamming_divs, tpjs_divs = evaluate_diversity(process_onehot(fake))
-            #
-            #         # items = (t+1, y_real, y_fake, hamming_divs, tpjs_divs, time.time() - start_time)
-            #         # log_writer.writerow(items)
-            #         print(
-            #             'Iteration %d, y-real=%.3g, y-fake=%.3g, Hamming-divs: %.5g, TPJS-divs: %.5g, '
-            #             'time: %.1fs' % items
-            #         )
-            #     netD.train()
-            #     netG.train()
             if t % args.save_itv == (args.save_itv - 1):
                 netG.eval()
                 netD.eval()
src/gan/gankits.py
CHANGED
@@ -3,7 +3,7 @@ from src.smb.level import MarioLevel
 from src.gan.gans import nz
 from src.utils.filesys import getpath
 
-
+# sample latent noise
 def sample_latvec(n=1, device='cpu', distribuion='uniform'):
     if distribuion == 'uniform':
         return torch.rand(n, nz, 1, 1, device=device) * 2 - 1
@@ -12,6 +12,7 @@ def sample_latvec(n=1, device='cpu', distribuion='uniform'):
     else:
         raise TypeError(f'unknow noise distribution: {distribuion}')
 
+# process the one-hot arrays into level segments
 def process_onehot(raw_tensor_onehot):
     H, W = MarioLevel.height, MarioLevel.seg_width
     res = []
@@ -26,6 +27,5 @@ def get_decoder(path='models/decoder.pth', device='cpu'):
     decoder.requires_grad_(False)
     decoder.eval()
     return decoder
-    pass
 
 
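Taken together, the two functions commented above form the decode pipeline: sample latent noise, run it through the frozen decoder, and turn the one-hot output into level segments. A usage sketch built only from functions visible in this diff; it assumes `models/decoder.pth` exists, as in the repo:

```python
import torch
from src.gan.gankits import sample_latvec, get_decoder, process_onehot

decoder = get_decoder('models/decoder.pth')   # frozen, eval-mode decoder
z = sample_latvec(n=4)                        # (4, nz, 1, 1), uniform in [-1, 1)
with torch.no_grad():
    segs = process_onehot(decoder(z))         # four level segments
print(len(segs))
```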
src/gan/gans.py
CHANGED
@@ -5,7 +5,7 @@ from src.utils.dl import SelfAttn
 
 nz = 20
 
-
+# Self-Attention GAN
 class SAGenerator(nn.Module):
     def __init__(self, base_channels=32):
         super(SAGenerator, self).__init__()
src/olgen/olg_policy.py
CHANGED
@@ -1,6 +1,6 @@
 import glob
 import random
-from abc import abstractmethod
+from abc import abstractmethod
 import numpy as np
 import torch
 from src.utils.filesys import getpath
@@ -49,8 +49,6 @@ class RLGenPolicy(GenPolicy):
        if d < nz * self.n:
            obs = torch.cat([torch.zeros([b, nz * self.n - d], device=self.device), obs], dim=-1)
        with torch.no_grad():
-            # mus, sigmas, betas = self.model.get_intermediate(obs)
-            # print(mus[0].cpu().numpy(), '\n', betas[0].cpu().numpy(), '\n')
            model_output, _ = self.model(obs)
        return torch.clamp(model_output, -1, 1).squeeze().cpu().numpy()
 
@@ -60,43 +58,6 @@ class RLGenPolicy(GenPolicy):
        n = load_cfgs(path, 'N')
        return RLGenPolicy(model, n, device)
 
-#
-# class SunriseGenPolicy(GenPolicy):
-#     def __init__(self, models, n, device='cpu'):
-#         super(SunriseGenPolicy, self).__init__(n)
-#         for model in models:
-#             model.to(device)
-#         self.models = models
-#         self.m = len(self.models)
-#
-#         self.agent = SunriseProxyAgent(models, device)
-#
-#     def step(self, obs):
-#         actions = [m(obs.unsqueeze()).squeeze().cpu().numpy() for m in self.models]
-#         if len(obs.shape) == 1:
-#             return random.choice(actions)
-#         else:
-#             actions = np.array(actions)
-#             selections = [random.choice(range(self.m)) for _ in range(len(obs))]
-#             selected = [actions[s, i, :] for i, s in enumerate(selections)]
-#             return np.array(selected)
-#     #
-#     # def reset(self):
-#     #     # self.agent.reset()
-#     #     pass
-#
-#     @staticmethod
-#     def from_path(path, device='cpu'):
-#         models = [
-#             torch.load(p, map_location=device)
-#             for p in glob.glob(getpath(path, 'policy*.pth'))
-#         ]
-#         n = load_cfgs(path, 'N')
-#         return SunriseGenPolicy(models, n, device)
-#
-#     # @property
-#     # def device(self):
-#     #     return self.agent.device
 
 
 class EnsembleGenPolicy(GenPolicy):
@@ -114,20 +75,25 @@ class EnsembleGenPolicy(GenPolicy):
        actions = []
        with torch.no_grad():
            for m in self.models:
-                a = m(o)
+                a = m(o)  # each model predicts an action
                if type(a) == tuple:
                    a = a[0]
                actions.append(torch.clamp(a, -1, 1).cpu().numpy())
        if len(obs.shape) == 1:
            return random.choice(actions)
        else:
+            # for each observation, every one of the m models outputs an action, then one of them is picked at random
            actions = np.array(actions)
+            # self.m is the number of models, i.e. len(self.models)
            selections = [random.choice(range(self.m)) for _ in range(len(obs))]
            selected = [actions[s, i, :] for i, s in enumerate(selections)]
            return np.array(selected)
 
    @staticmethod
    def from_path(path, device='cpu'):
+        """
+        Load all the models under path
+        """
        models = [
            torch.load(p, map_location=device)
            for p in glob.glob(getpath(path, 'policy*.pth'))
@@ -141,9 +107,6 @@ class RandGenPolicy(GenPolicy):
        super(RandGenPolicy, self).__init__(1)
 
    def step(self, obs):
-        # if len(obs.shape) == 1:
-        #     return sample_latvec(1).squeeze().numpy()
-        # else:
        n = obs.shape[0]
        return sample_latvec(n).squeeze().numpy()
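The comments added to `EnsembleGenPolicy.step` describe its batched branch: each of the `m` models proposes an action for every observation, then one proposal is picked uniformly at random per row. A NumPy-only sketch of that selection step (names are illustrative):

```python
import random
import numpy as np

def ensemble_select(actions_per_model):
    actions = np.array(actions_per_model)      # (m, batch, act_dim)
    m, batch, _ = actions.shape
    selections = [random.randrange(m) for _ in range(batch)]   # one model per row
    return np.array([actions[s, i, :] for i, s in enumerate(selections)])

acts = ensemble_select([np.zeros((3, 2)), np.ones((3, 2))])    # m=2, batch=3
print(acts.shape)   # (3, 2)
```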
src/smb/asyncsimlt.py
CHANGED
@@ -54,14 +54,11 @@ def _simlt_worker(remote, parent_remote, rfunc, resource):
                min_dtw = min(min_dtw, vdtw)
            remote.send((min_hm, min_dtw))
        elif cmd == 'mpd':
-            # strpairs = data
            hms, dtws = [], []
            for strlvl1, strlvl2 in data:
                lvl1, lvl2 = MarioLevel(strlvl1), MarioLevel(strlvl2)
                hms.append(hamming_dis(lvl1, lvl2))
-                # dtws.append(lvl_dtw(lvl1, lvl2))
            remote.send((hms, None))
-            # remote.send((hms, dtws))
        else:
            raise KeyError(f'Unknown command for simulation worker: {cmd}')
    except EOFError:
@@ -70,6 +67,9 @@ def _simlt_worker(remote, parent_remote, rfunc, resource):
 
 
 class AsycSimltPool:
+    """
+    Asynchronous pool for multi-process Mario simulation tasks
+    """
    def __init__(self, poolsize, queuesize=None, rfunc_name='default', verbose=True, **rsrc):
        self.np, self.nq = poolsize, poolsize if queuesize is None else queuesize
        self.waiting_queue = Queue(self.nq)
@@ -149,6 +149,7 @@ class AsycSimltPool:
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            args = (work_remote, remote, rfunc, resource)
            # daemon=True: if the main process crashes, we should not cause things to hang
+            # spawn worker processes to do the computation asynchronously
            process = ctx.Process(target=_simlt_worker, args=args, daemon=True)  # pytype:disable=attribute-error
            process.start()
            self.processes.append(process)
@@ -162,12 +163,6 @@ class AsycSimltPool:
        time.sleep(0.01)
 
    def close(self):
-        # finish = False
-        # while not finish:
-        #     self.refresh()
-        #     finish = all(r for r in self.ready)
-        #     time.sleep(0.01)
-        # self.__wait()
        res = self.get(True)
        for remote, p in zip(self.remotes, self.processes):
            remote.send(('close', None))
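The docstring and comment added to `AsycSimltPool` describe its design: daemon worker processes receive `(cmd, data)` tuples over pipes, reply with results, and shut down on a `'close'` command. A minimal runnable sketch of that command-loop pattern; the `'square'` command stands in for the pool's simulation commands:

```python
import multiprocessing as mp

def worker(remote, parent_remote):
    parent_remote.close()                  # child does not use the parent's end
    while True:
        cmd, data = remote.recv()
        if cmd == 'close':
            remote.close()
            break
        elif cmd == 'square':
            remote.send(data * data)

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    remote, work_remote = ctx.Pipe()
    p = ctx.Process(target=worker, args=(work_remote, remote), daemon=True)
    p.start()
    remote.send(('square', 7))
    print(remote.recv())                   # 49
    remote.send(('close', None))
    p.join()
```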
src/smb/proxy.py
CHANGED
@@ -9,10 +9,6 @@ from src.smb.level import MarioLevel, LevelRender
 from src.utils.filesys import getpath
 
 JVMPath = None
-# JVMPath = '/home/cseadmin/java/jdk1.8.0_301/jre/lib/amd64/server/libjvm.so'
-# JVMPath = '/home/liujl_lab/12132362/java/jdk1.8.0_301/jre/lib/amd64/server/libjvm.so'
-# JVMPath = '/home/liujl_lab/12132333/java/jdk1.8.0_301/jre/lib/amd64/server/libjvm.so'
-
 
 class MarioJavaAgents(Enum):
     Runner = 'agents.robinBaumgarten'
@@ -27,7 +23,6 @@ class MarioProxy:
     def __init__(self):
         if not jpype.isJVMStarted():
             jar_path = getpath('smb/Mario-AI-Framework.jar')
-            # print(f"-Djava.class.path={jar_path}/Mario-AI-Framework.jar")
             jpype.startJVM(
                 jpype.getDefaultJVMPath() if JVMPath is None else JVMPath,
                 f"-Djava.class.path={jar_path}", '-Xmx2g'
@@ -137,6 +132,3 @@ class MarioProxy:
 
 if __name__ == '__main__':
     simulator = MarioProxy()
-    # lvl = MarioLevel.from_file('smb/levels/lvl-1.lvl')
-    # print(simulator.simulate_complete(lvl))
-    # print(simulator.play_game(lvl))
src/utils/img.py
CHANGED
@@ -10,7 +10,6 @@ def make_img_sheet(imgs, ncols, x_margin=6, y_margin=6, save_path='./image.png',
     w_canvas = (w + x_margin) * ncols - x_margin
     h_canvas = (h + y_margin) * nrows - y_margin
     canvas = Image.new('RGBA', (w_canvas, h_canvas), (0, 0, 0, 0))
-    # canvas.fill(margin_color)
     for i in range(len(imgs)):
         row_id, col_id = i // ncols, i % ncols
         canvas.paste(imgs[i], ((w + x_margin) * col_id, (h + y_margin) * row_id), imgs[i])
test_ddpm.py
CHANGED
@@ -4,16 +4,10 @@ import torch
 import torch.optim as optim
 import torch.nn as nn
 import logging
-# from tqdm import tqdm
-# from torch.utils.tensorboard import SummaryWriter
 from src.ddpm.diffusion import Diffusion
 from src.ddpm.modules import UNet
-# from pytorch_model_summary import summary
-# from matplotlib import pyplot as plt
 from src.ddpm.dataset import create_dataloader
-# from utils.plot import get_img_from_level
 from pathlib import Path
-# from src.smb.level import MarioLevel
 import argparse
 import datetime
 
@@ -21,16 +15,12 @@ from src.gan.gankits import process_onehot, get_decoder
 from src.smb.level import MarioLevel, lvlhcat, save_batch
 from src.utils.filesys import getpath
 from src.utils.img import make_img_sheet
-
-# sprite_counts = np.power(np.array([102573, 9114, 1017889, 930, 3032, 7330, 2278, 2279, 5227, 5229, 5419]), 1/4)
 sprite_counts = np.power(np.array([
     74977, 15252, 572591, 5826, 1216, 7302, 237, 237, 2852, 1074, 235, 304, 48, 96, 160, 1871, 936, 186, 428, 80, 428
     ]), 1/4
 )
 min_count = np.min(sprite_counts)
 
-# filepath = Path(__file__).parent.resolve()
-# DATA_PATH = os.path.join(filepath, "levels", "ground", "unique_onehot.npz")
 
 def setup_logging(run_name, beta_schedule):
     model_path = os.path.join("models", beta_schedule, run_name)
@@ -39,41 +29,8 @@ def setup_logging(run_name, beta_schedule):
     os.makedirs(result_path, exist_ok=True)
     return model_path, result_path
 
-#
-#     fig = plt.figure(figsize=(30, 15))
-#     for i in range(len(sampled_images)):
-#         ax1 = fig.add_subplot(4, int(len(sampled_images)/4), i+1)
-#         ax1.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
-#         level = sampled_images[i].argmax(dim=0).cpu().numpy()
-#         level_img = get_img_from_level(level)
-#         ax1.imshow(level_img)
-#     plt.savefig(os.path.join(result_path, f"{epoch:04d}_sample.png"))
-#     plt.close()
-
-# def plot_training_images(epoch, original_img, x_t, noise, predicted_noise, reconstructed_img, training_result_path):
-#     fig = plt.figure(figsize=(15, 10))
-#     for i in range(2):
-#         ax1 = fig.add_subplot(2, 5, i*5+1)
-#         ax1.imshow(get_img_from_level(original_img[i].cpu().numpy()))
-#         ax1.set_title(f"Original {i}")
-#         ax2 = fig.add_subplot(2, 5, i*5+2)
-#         ax2.imshow(get_img_from_level(noise[i].cpu().numpy()))
-#         ax2.set_title(f"Noise {i}")
-#         ax3 = fig.add_subplot(2, 5, i*5+3)
-#         ax3.imshow(get_img_from_level(x_t.argmax(dim=1).cpu().numpy()[i]))
-#         ax3.set_title(f"x_t {i}")
-#         ax4 = fig.add_subplot(2, 5, i*5+4)
-#         ax4.imshow(get_img_from_level(predicted_noise[i].cpu().numpy()))
-#         ax4.set_title(f"Predicted Noise {i}")
-#         ax5 = fig.add_subplot(2, 5, i*5+5)
-#         ax5.imshow(get_img_from_level(reconstructed_img.probs.argmax(dim=-1).cpu().numpy()[i]))
-#         ax5.set_title(f"Reconstructed Image {i}")
-#     plt.savefig(os.path.join(training_result_path, f"{epoch:04d}.png"))
-#     plt.close()
-
 def train(args):
-    # model_path, result_path = setup_logging(args.run_name, args.beta_schedule)
-    # training_result_path = os.path.join(result_path, "training")
     path = getpath(args.res_path)
     os.makedirs(path, exist_ok=True)
 
@@ -83,24 +40,15 @@ def train(args):
     optimizer = optim.AdamW(model.parameters(), lr=args.lr)
     mse = nn.MSELoss()
     diffusion = Diffusion(device=device, schedule=args.beta_schedule)
-    # logger = SummaryWriter(os.path.join("logs", args.beta_schedule, args.run_name))
     temperatures = torch.tensor(min_count / sprite_counts, dtype=torch.float32).to(device)
     l = len(dataloader)
 
-    # print(summary(model, torch.zeros((64, MarioLevel.n_types, 14, 14)).to(device), diffusion.sample_timesteps(64).to(device), show_input=True))
-
-    # if args.resume_from != 0:
-    #     checkpoint = torch.load(os.path.join(model_path, f'ckpt_{args.resume_from}'))
# checkpoint = torch.load(os.path.join(model_path, f'ckpt_{args.resume_from}'))
|
94 |
-
# model.load_state_dict(checkpoint['model_state_dict'])
|
95 |
-
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
96 |
|
97 |
for epoch in range(args.resume_from+1, args.resume_from+args.epochs+1):
|
98 |
logging.info(f"Starting epoch {epoch}:")
|
99 |
epoch_loss = {'rec_loss': 0, 'mse': 0, 'loss': 0}
|
100 |
-
# pbar = tqdm(dataloader)
|
101 |
for i, images in enumerate(dataloader):
|
102 |
images = images.to(device)
|
103 |
-
# print(images.shape)
|
104 |
t = diffusion.sample_timesteps(images.shape[0]).to(device) # random int from 1~1000
|
105 |
x_t, noise = diffusion.noise_images(images, t) # x_t: image with noise at t, noise: gaussian noise
|
106 |
predicted_noise = model(x_t.float(), t.float()) # returns predicted noise eps_theta
|
@@ -117,38 +65,13 @@ def train(args):
|
|
117 |
loss.backward()
|
118 |
optimizer.step()
|
119 |
|
120 |
-
# pbar.set_postfix(LOSS=loss.item())
|
121 |
-
# logger.add_scalar("Rec_loss", rec_loss.item(), global_step=(epoch - 1) * l + i)
|
122 |
-
# logger.add_scalar("MSE", mse_loss.item(), global_step=(epoch - 1) * l + i)
|
123 |
-
# logger.add_scalar("LOSS", loss.item(), global_step=(epoch - 1) * l + i)
|
124 |
-
|
125 |
-
# logger.add_scalar("Epoch_Rec_loss", epoch_loss['rec_loss']/l, global_step=epoch)
|
126 |
-
# logger.add_scalar("Epoch_MSE", epoch_loss['mse']/l, global_step=epoch)
|
127 |
-
# logger.add_scalar("Epoch_LOSS", epoch_loss['loss']/l, global_step=epoch)
|
128 |
print(
|
129 |
'\nIteration: %d' % epoch,
|
130 |
'rec_loss: %.5g' % (epoch_loss['rec_loss']/l),
|
131 |
'mse: %.5g' % (epoch_loss['mse']/l)
|
132 |
)
|
133 |
|
134 |
-
# if epoch % 20 == 19:
|
135 |
-
# sampled_images = diffusion.sample(model, n=50)
|
136 |
-
# imgs = [lvl.to_img() for lvl in process_onehot(sampled_images[-1])]
|
137 |
-
# make_img_sheet(imgs, 10, save_path=f'{args.res_path}/sample{epoch+1}.png')
|
138 |
-
|
139 |
-
# plot_images(epoch, sampled_images[-1], result_path)
|
140 |
-
# plot_training_images(epoch, original_img, x_t, noise.argmax(dim=1), predicted_noise.argmax(dim=1), reconstructed_img, training_result_path)
|
141 |
-
|
142 |
if epoch % 1000 == 0:
|
143 |
-
# torch.save(model.state_dict(), os.path.join(model_path, f"ckpt_{epoch:04d}.pt"))
|
144 |
-
# torch.save({
|
145 |
-
# 'epoch': epoch,
|
146 |
-
# 'model_state_dict': model.state_dict(),
|
147 |
-
# 'optimizer_state_dict': optimizer.state_dict(),
|
148 |
-
# 'Epoch_Rec_loss': epoch_loss['rec_loss']/l,
|
149 |
-
# 'Epoch_MSE': epoch_loss['mse']/l,
|
150 |
-
# 'Epoch_LOSS': epoch_loss['loss']/l
|
151 |
-
# }, getpath(f"{args.res_path}/ddpm_{epoch}.pt"))
|
152 |
itpath = getpath(path, f'it{epoch}')
|
153 |
os.makedirs(itpath, exist_ok=True)
|
154 |
model.save(getpath(path, itpath, 'ddpm.pth'))
|
@@ -173,11 +96,8 @@ def train(args):
|
|
173 |
def launch():
|
174 |
parser = argparse.ArgumentParser()
|
175 |
parser.add_argument("--epochs", type=int, default=10000)
|
176 |
-
# parser.add_argument("--data_path", type=str, default=DATA_PATH)
|
177 |
parser.add_argument("--batch_size", type=int, default=256)
|
178 |
parser.add_argument("--res_path", type=str, default='exp_data/DDPM')
|
179 |
-
# parser.add_argument("--image_size", type=int, default=14)
|
180 |
-
# parser.add_argument("--device", type=str, default="cuda")
|
181 |
parser.add_argument("--gpuid", type=int, default=0)
|
182 |
parser.add_argument("--lr", type=float, default=3e-4)
|
183 |
parser.add_argument("--beta_schedule", type=str, default="quadratic", choices=['linear', 'quadratic', 'sigmoid'])
|
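The `temperatures` tensor above down-weights frequent sprite types: each count is raised to the power 1/4 and the rarest type (count 48) sets the scale, so it gets weight 1 while commoner types get proportionally smaller weights. A quick standalone check of that computation, using the same values as test_ddpm.py:

```python
import numpy as np

# Same weighting as test_ddpm.py: quartic-root counts, normalised by the rarest type.
sprite_counts = np.power(np.array([
    74977, 15252, 572591, 5826, 1216, 7302, 237, 237, 2852, 1074, 235, 304, 48, 96, 160, 1871, 936, 186, 428, 80, 428
]), 1/4)
temperatures = np.min(sprite_counts) / sprite_counts
print(temperatures.round(3))  # 1.0 for count 48, about 0.096 for count 572591
```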
test_gen_log.py
DELETED
@@ -1,15 +0,0 @@
-import time
-import argparse
-from tests import evaluate_rewards, evaluate_gen_log
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--path', type=str)
-    parser.add_argument('--parallel', type=int, default=50)
-    parser.add_argument('--rfunc', type=str)
-    args = parser.parse_args()
-    start = time.time()
-    evaluate_gen_log(args.path, args.rfunc, parallel=args.parallel)
-    print(f'Evaluation for {args.path} finished,', '%.2f' % (time.time() - start))
-    pass
test_gen_samples.py
DELETED
@@ -1,24 +0,0 @@
-import json
-import argparse
-import time
-
-import numpy as np
-from tests import evaluate_rewards, evaluate_mpd
-from src.smb.level import load_batch
-from src.utils.filesys import getpath
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--path', type=str)
-    parser.add_argument('--parallel', type=int, default=50)
-    parser.add_argument('--rfunc', type=str)
-    args = parser.parse_args()
-    start = time.time()
-    lvls = load_batch(getpath(args.path, 'samples.lvls'))
-    rewards = [sum(item) for item in evaluate_rewards(lvls, args.rfunc, parallel=args.parallel)]
-    diversity = evaluate_mpd(lvls)
-    with open(getpath(args.path, 'performance.csv'), 'w') as f:
-        json.dump({'reward': np.mean(rewards), 'diversity': diversity}, f)
-    print(f'Evaluation for {args.path} finished,', '%.2f' % (time.time() - start))
-
-    pass
tests.py
DELETED
@@ -1,140 +0,0 @@
-import csv
-import time
-
-import torch
-
-from plots import print_compare_tab_nonrl
-from src.gan.gankits import *
-from src.smb.level import *
-from itertools import combinations, chain
-from src.utils.filesys import getpath
-from src.smb.asyncsimlt import AsycSimltPool
-
-
-def evaluate_rewards(lvls, rfunc='default', dest_path='', parallel=1, eval_pool=None):
-    internal_pool = eval_pool is None
-    if internal_pool:
-        eval_pool = AsycSimltPool(parallel, rfunc_name=rfunc, verbose=False, test=True)
-    res = []
-    for lvl in lvls:
-        eval_pool.put('evaluate', (0, str(lvl)))
-        buffer = eval_pool.get()
-        for _, item in buffer:
-            res.append([sum(r) for r in zip(*item.values())])
-    if internal_pool:
-        buffer = eval_pool.close()
-    else:
-        buffer = eval_pool.get(True)
-    for _, item in buffer:
-        res.append([sum(r) for r in zip(*item.values())])
-    if len(dest_path):
-        np.save(dest_path, res)
-    return res
-
-def evaluate_mpd(lvls, parallel=2):
-    task_datas = [[] for _ in range(parallel)]
-    for i, (A, B) in enumerate(combinations(lvls, 2)):
-        # lvlA, lvlB = lvls[i * 2], lvls[i * 2 + 1]
-        task_datas[i % parallel].append((str(A), str(B)))
-
-    hms, dtws = [], []
-    eval_pool = AsycSimltPool(parallel, verbose=False)
-    for task_data in task_datas:
-        eval_pool.put('mpd', task_data)
-    res = eval_pool.get(wait=True)
-    for task_hms, _ in res:
-        hms += task_hms
-    return np.mean(hms)
-
-def evaluate_gen_log(path, rfunc_name, parallel=5):
-    f = open(getpath(f'{path}/step_tests.csv'), 'w', newline='')
-    wrtr = csv.writer(f)
-    cols = ['step', 'r-avg', 'r-std', 'diversity']
-    wrtr.writerow(cols)
-    start_time = time.time()
-    for lvls, name in traverse_batched_level_files(f'{path}/gen_log'):
-        step = name[4:]
-        rewards = [sum(item) for item in evaluate_rewards(lvls, rfunc_name, parallel=parallel)]
-        r_avg, r_std = np.mean(rewards), np.std(rewards)
-        mpd = evaluate_mpd(lvls, parallel=parallel)
-        line = [step, r_avg, r_std, mpd]
-        wrtr.writerow(line)
-        f.flush()
-        print(
-            f'{path}: step{step} evaluated in {time.time()-start_time:.1f}s -- '
-            + '; '.join(f'{k}: {v}' for k, v in zip(cols, line))
-        )
-    f.close()
-    pass
-
-
-
-
-if __name__ == '__main__':
-    # print_compare_tab_nonrl()
-
-    arr = [[1, 2], [1, 2]]
-    arr = [*chain(*arr)]
-    print(arr)
-    for i in range(5):
-        path = f'training_data/GAN{i}'
-        lvls = []
-        init_lateves = torch.tensor(np.load(getpath('analysis/initial_seg.npy')), device='cuda:0')
-        decoder = get_decoder(device='cuda:0')
-        init_seg_onehots = decoder(init_lateves.view(*init_lateves.shape, 1, 1))
-        gan = get_decoder(f'{path}/decoder.pth', device='cuda:0')
-        for init_seg_onehot in init_seg_onehots:
-            seg_onehots = gan(sample_latvec(25, device='cuda:0'))
-            a = init_seg_onehot.view(1, *init_seg_onehot.shape)
-            b = seg_onehots
-            # print(a.shape, b.shape)
-            segs = process_onehot(torch.cat([a, b], dim=0))
-            level = lvlhcat(segs)
-            lvls.append(level)
-        save_batch(lvls, getpath(path, 'samples.lvls'))
-        lvls = load_batch(f'{path}/samples.lvls')[:15]
-        imgs = [lvl.to_img() for lvl in lvls]
-        make_img_sheet(imgs, 1, save_path=f'generation_results/GAN/trial{i+1}/sample_lvls.png')
-
-    ts = torch.tensor([
-        [[0, 0], [0, 1], [0, 2]],
-        [[1, 0], [1, 1], [1, 2]],
-    ])
-    print(ts.shape)
-    print(ts[[*range(2)], [1, 2], :])
-    task = 'fhp'
-    parallel = 50
-    samples = []
-    for algo in ['dvd', 'egsac', 'pmoe', 'sunrise', 'asyncsac', 'sac']:
-        for t in range(5):
-            lvls = load_batch(getpath('test_data', algo, task, f't{t + 1}', 'samples.lvls'))
-            samples += lvls
-    for l in ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']:
-        for t in range(5):
-            lvls = load_batch(getpath('test_data', f'varpm-{task}', f'l{l}_m5', f't{t + 1}', 'samples.lvls'))
-            samples += lvls
-
-    # task_datas = [[] for _ in range(parallel)]
-    # for i, (A, B) in enumerate(combinations(samples, 2)):
-    #     lvlA, lvlB = lvls[i * 2], lvls[i * 2 + 1]
-    #     task_datas[i % parallel].append((str(A), str(B)))
-
-    distmat = []
-    eval_pool = AsycSimltPool(parallel, verbose=False)
-    for A in samples:
-        eval_pool.put('mpd', [(str(A), str(B)) for B in samples])
-    res = eval_pool.get()
-    for task_hms, _ in res:
-        hms += task_hms
-    np.save(getpath('test_data', f'samples_dists-{task}.npy'), hms)
-
-    start = time.time()
-    samples = load_batch(getpath('test_data/varpm-fhp/l0.0_m2/t1/samples.lvls'))
-    distmat = []
-    for a in samples:
-        dist_list = []
-        for b in samples:
-            dist_list.append(hamming_dis(a, b))
-        distmat.append(dist_list)
-    print(time.time() - start)
-    pass
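The deleted `evaluate_mpd` measures mean pairwise diversity: the average tile-wise Hamming distance over all unordered level pairs, computed in parallel via AsycSimltPool. A minimal serial equivalent, assuming levels are equal-shape numpy arrays of tile ids (a sketch, not the repository's exact distance definition):

```python
from itertools import combinations
import numpy as np

def mean_pairwise_diversity(levels):
    """Average fraction of differing tiles over all level pairs (serial sketch)."""
    dists = [np.mean(a != b) for a, b in combinations(levels, 2)]
    return float(np.mean(dists))
```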
train.py
CHANGED
@@ -46,4 +46,6 @@ if __name__ == '__main__':
     args = parser.parse_args()
 
     entry = args.entry
+
+    # entry is the training entry point of each model; the concrete function is defined inside each subparser
     entry(args)
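The new comment refers to the usual argparse dispatch pattern: each sub-command binds its own training function via `set_defaults(entry=...)`. A hypothetical sketch of that wiring (the function and flag names here are illustrative, not the repository's actual ones):

```python
import argparse

def train_sac(args):
    # Hypothetical entry function; the real ones live in the training modules.
    print(f'training SAC with lr={args.lr}')

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
sac = subparsers.add_parser('sac')
sac.add_argument('--lr', type=float, default=3e-4)
sac.set_defaults(entry=train_sac)   # each subparser defines its own entry

args = parser.parse_args(['sac', '--lr', '0.0001'])
args.entry(args)                    # dispatches to the bound training function
```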