baiyanlali-zhao committed on
Commit
3582c8a
1 Parent(s): 8be1cb6

Add comments

README.md CHANGED
@@ -8,26 +8,39 @@ python_version: 3.9
8
  app_file: app.py
9
  pinned: false
10
  ---
 
 
11
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Negatively Correlated Ensemble RL
 
 
 
14
 
 
 
 
 
 
 
15
 
16
- ### Verified environment
17
- * Python 3.9.6
18
- * JPype 1.3.0
19
- * dtw 1.4.0
20
- * scipy 1.7.2
21
- * torch 1.8.2+cu111
22
- * numpy 1.20.3
23
- * gym 0.21.0
24
- * scipy 1.7.2
25
- * Pillow 10.0.0
26
- * matplotlib 3.6.3
27
- * pandas 1.3.2
28
- * sklearn 1.0.1
29
 
30
- ### How to use
31
 
32
  All training runs are launched by running `train.py` with an option and arguments. For example, `python train.py ncesac --lbd 0.3 --m 5` trains NCERL with hyperparameters $\lambda = 0.3, m=5$.
33
  The plotting script is `plots.py`.
@@ -36,9 +49,38 @@ All training are launched by running `train.py` with option and arguments. For e
36
  * `python train.py sac`: to train a standard SAC as the policy for online game level generation
37
  * `python train.py asyncsac`: to train a SAC with an asynchronous evaluation environment as the policy for online game level generation
38
  * `python train.py ncesac`: to train an NCERL based on SAC as the policy for online game level generation
39
- * `python train.py egsac`: to train an episodic generative SAC (see paper [*The fun facets of Mario: Multifaceted experience-driven PCG via reinforcement learning*](https://dl.acm.org/doi/abs/10.1145/3555858.3563282?casa_token=AHQWYSj_GyoAAAAA:MhwOltqfijP1NQj-c6NaTQikCnlNwyaMky07gCvTK5ZlSq063ew40awAcqEcw6S5zG9Sq9ZyDsspuaM)) as the policy for online game level generation
40
  * `python train.py pmoe`: to train a PMOE policy (see paper [*Probabilistic Mixture-of-Experts for Efficient Deep Reinforcement Learning*](https://arxiv.org/abs/2104.09122)) as the policy for online game level generation
41
  * `python train.py sunrise`: to train a SUNRISE (see paper [*SUNRISE: A Simple Unified Framework for Ensemble Learning in Deep Reinforcement Learning*](https://proceedings.mlr.press/v139/lee21g.html)) as the policy for online game level generation
42
  * `python train.py dvd`: to train a DvD-SAC (see paper [*Effective Diversity in Population Based Reinforcement Learning*](https://proceedings.neurips.cc/paper_files/paper/2020/hash/d1dc3a8270a6f9394f88847d7f0050cf-Abstract.html)) as the policy for online game level generation
43
 
44
  For the training arguments, please refer to the help `python train.py [option] --help`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  app_file: app.py
9
  pinned: false
10
  ---
11
+ ![alt text](./media/banner.png)
12
+ # Negatively Correlated Ensemble RL
13
 
14
+ ## Environment setup
15
+ Create the conda environment:
16
+ ```bash
17
+ conda create -n ncerl python=3.9
18
+ ```
19
+ Install the dependencies:
20
+ ```bash
21
+ pip install -r requirements.txt
22
+ ```
23
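+ Note: this program does not require a GPU, but PyTorch must be installed. If your GPU supports CUDA, install the CUDA build; otherwise install the CPU build. The CUDA build speeds up inference.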
24
 
25
+ Activate the conda environment:
26
+ ```
27
+ conda activate ncerl
28
+ ```
29
 
30
+ ## Quick start
31
+ To see the demo, run
32
+ ```
33
+ python app.py
34
+ ```
35
+ and open the link printed in the terminal to interact with it.
36
 
37
+ Alternatively, run
38
+ ```
39
+ python generate_and_play.py
40
+ ```
41
+ and then open `models/example_policy/samples.png` to view the generated levels.
 
 
 
 
 
 
 
 
42
 
43
+ ## Training
44
 
45
  All training runs are launched by running `train.py` with an option and arguments. For example, `python train.py ncesac --lbd 0.3 --m 5` trains NCERL with hyperparameters $\lambda = 0.3, m=5$.
46
  The plotting script is `plots.py`.
 
49
  * `python train.py sac`: to train a standard SAC as the policy for online game level generation
50
  * `python train.py asyncsac`: to train a SAC with an asynchronous evaluation environment as the policy for online game level generation
51
  * `python train.py ncesac`: to train an NCERL based on SAC as the policy for online game level generation
52
+ * `python train.py egsac`: to train an episodic generative SAC (see paper [*The fun facets of Mario: Multifaceted experience-driven PCG via reinforcement learning*](https://dl.acm.org/doi/abs/10.1145/3555858.3563282)) as the policy for online game level generation
53
  * `python train.py pmoe`: to train a PMOE policy (see paper [*Probabilistic Mixture-of-Experts for Efficient Deep Reinforcement Learning*](https://arxiv.org/abs/2104.09122)) as the policy for online game level generation
54
  * `python train.py sunrise`: to train a SUNRISE (see paper [*SUNRISE: A Simple Unified Framework for Ensemble Learning in Deep Reinforcement Learning*](https://proceedings.mlr.press/v139/lee21g.html)) as the policy for online game level generation
55
  * `python train.py dvd`: to train a DvD-SAC (see paper [*Effective Diversity in Population Based Reinforcement Learning*](https://proceedings.neurips.cc/paper_files/paper/2020/hash/d1dc3a8270a6f9394f88847d7f0050cf-Abstract.html)) as the policy for online game level generation
56
 
57
  For the training arguments, please refer to the help `python train.py [option] --help`
58
+
59
+ ## Directory structure
60
+ ```
61
+ NCERL-DIVERSE-PCG/
62
+ * analysis/
63
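+ * generate.py unused
64
+ * tests.py used for evaluation
65
+ * media/ assets for the markdown files
66
+ * models/
67
+ * example_policy/ used for the generation demo
68
+ * smb/ Mario simulator and image resources
69
+ * src/
70
+ * ddpm/ DDPM model code
71
+ * drl/ DRL models and training
72
+ * env/ Mario gym environment and reward functions
73
+ * gan/ GAN models and training
74
+ * olgen/ online generation environment and policies
75
+ * rlkit/ reinforcement learning building blocks
76
+ * smb/ components for interacting with the Mario simulator, plus the multi-process asynchronous pool
77
+ * utils/ utility modules
78
+ * training_data/ training data
79
+ * README.md this file
80
+ * app.py Gradio demo
81
+ * generate_and_play.py non-Gradio demo
82
+ * train.py training entry point
83
+ * test_ddpm.py script for testing DDPM training
84
+ * requirements.txt dependency list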
85
+ ```
86
+
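The README describes `train.py` as a command with one option per algorithm plus per-option arguments, each with its own `--help`. The sketch below shows one common way such a CLI can be laid out with `argparse` subcommands. It is not the repository's `train.py`: only the option names and the `--lbd` and `--m` flags come from the README, while `build_parser`, the argument types, and the default values are assumptions made for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical stand-in for the CLI described in the README."""
    parser = argparse.ArgumentParser(prog="train.py")
    subparsers = parser.add_subparsers(dest="algo", required=True)

    # One subcommand per training option listed in the README.
    for name in ("sac", "asyncsac", "ncesac", "egsac", "pmoe", "sunrise", "dvd"):
        sub = subparsers.add_parser(name)
        if name == "ncesac":
            # Flag names follow the README example; types and defaults are assumed.
            sub.add_argument("--lbd", type=float, default=0.0)
            sub.add_argument("--m", type=int, default=2)
    return parser


# Equivalent to: python train.py ncesac --lbd 0.3 --m 5
args = build_parser().parse_args(["ncesac", "--lbd", "0.3", "--m", "5"])
print(args.algo, args.lbd, args.m)
```

With this layout, `python train.py ncesac --help` prints only the arguments of that subcommand, which matches the `python train.py [option] --help` usage the README mentions.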
analysis/tests.py CHANGED
@@ -36,7 +36,6 @@ def evaluate_rewards(lvls, rfunc='default', dest_path='', parallel=1, eval_pool=
36
 
37
  def evaluate_mnd(lvls, refs, parallel=2):
38
  eval_pool = AsycSimltPool(parallel, verbose=False, refs=[str(ref) for ref in refs])
39
- # m, _ = len(lvls), len(refs)
40
  res = []
41
  for lvl in lvls:
42
  eval_pool.put('mnd_item', str(lvl))
@@ -49,7 +48,6 @@ def evaluate_mnd(lvls, refs, parallel=2):
49
  def evaluate_mpd(lvls, parallel=2):
50
  task_datas = [[] for _ in range(parallel)]
51
  for i, (A, B) in enumerate(combinations(lvls, 2)):
52
- # lvlA, lvlB = lvls[i * 2], lvls[i * 2 + 1]
53
  task_datas[i % parallel].append((str(A), str(B)))
54
 
55
  hms, dtws = [], []
@@ -73,7 +71,6 @@ def evaluate_gen_log(path, parallel=5):
73
  step = name[4:]
74
  rewards = [sum(item) for item in evaluate_rewards(lvls, rfunc_name, parallel=parallel)]
75
  r_avg, r_std = np.mean(rewards), np.std(rewards)
76
- # mpd_hm, mpd_dtw = evaluate_mpd(lvls, parallel=parallel)
77
  mpd = evaluate_mpd(lvls, parallel=parallel)
78
  line = [step, r_avg, r_std, mpd, '']
79
  wrtr.writerow(line)
 
36
 
37
  def evaluate_mnd(lvls, refs, parallel=2):
38
  eval_pool = AsycSimltPool(parallel, verbose=False, refs=[str(ref) for ref in refs])
 
39
  res = []
40
  for lvl in lvls:
41
  eval_pool.put('mnd_item', str(lvl))
 
48
  def evaluate_mpd(lvls, parallel=2):
49
  task_datas = [[] for _ in range(parallel)]
50
  for i, (A, B) in enumerate(combinations(lvls, 2)):
 
51
  task_datas[i % parallel].append((str(A), str(B)))
52
 
53
  hms, dtws = [], []
 
71
  step = name[4:]
72
  rewards = [sum(item) for item in evaluate_rewards(lvls, rfunc_name, parallel=parallel)]
73
  r_avg, r_std = np.mean(rewards), np.std(rewards)
 
74
  mpd = evaluate_mpd(lvls, parallel=parallel)
75
  line = [step, r_avg, r_std, mpd, '']
76
  wrtr.writerow(line)
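The loop visible in `evaluate_mpd` enumerates every unordered pair of levels and assigns pair `i` to bucket `i % parallel`, so the pairwise work is spread round-robin over the workers. A minimal, self-contained sketch of just that bucketing step follows. The helper name `split_pairs` is hypothetical, and the sketch does not reproduce the simulation pool or the distance computation from the repository.

```python
from itertools import combinations


def split_pairs(items, parallel=2):
    """Distribute all unordered pairs of `items` over `parallel` buckets, round-robin."""
    buckets = [[] for _ in range(parallel)]
    for i, (a, b) in enumerate(combinations(items, 2)):
        # Pair i goes to bucket i % parallel, as in the loop shown in the diff.
        buckets[i % parallel].append((str(a), str(b)))
    return buckets


buckets = split_pairs(["A", "B", "C", "D"], parallel=2)
for worker_id, bucket in enumerate(buckets):
    print(worker_id, bucket)
```

For `n` items there are `n * (n - 1) / 2` pairs, so bucket sizes differ by at most one and each worker receives a near-equal share.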
app.py CHANGED
@@ -9,7 +9,6 @@ sys.path.append(path.dirname(path.abspath(__file__)))
9
 
10
 
11
  from src.olgen.ol_generator import VecOnlineGenerator
12
- # from src.olgen.olg_game import MarioOnlineGenGame
13
  from src.olgen.olg_policy import RLGenPolicy
14
  from src.smb.level import save_batch
15
  from src.utils.filesys import getpath
@@ -21,7 +20,7 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
21
 
22
  def generate_and_play():
23
  path = 'models/example_policy'
24
- # Generate with example policy model
25
  N, L = 8, 10
26
  plc = RLGenPolicy.from_path(path, device)
27
  generator = VecOnlineGenerator(plc, g_device=device)
@@ -29,14 +28,9 @@ def generate_and_play():
29
  os.makedirs(fd, exist_ok=True)
30
 
31
  lvls = generator.generate(N, L)
32
- # save_batch(lvls, f'{path}/samples.lvls')
33
  imgs = [lvl.to_img() for lvl in lvls]
34
  return imgs
35
- # make_img_sheet(imgs, 1, save_path=f'{path}/samples.png')
36
 
37
- # # Play with the example policy model
38
- # game = MarioOnlineGenGame(path)
39
- # game.play()
40
 
41
 
42
  with gr.Blocks(title="NCERL Demo") as demo:
 
9
 
10
 
11
  from src.olgen.ol_generator import VecOnlineGenerator
 
12
  from src.olgen.olg_policy import RLGenPolicy
13
  from src.smb.level import save_batch
14
  from src.utils.filesys import getpath
 
20
 
21
  def generate_and_play():
22
  path = 'models/example_policy'
23
+ # Generate with the example policy
24
  N, L = 8, 10
25
  plc = RLGenPolicy.from_path(path, device)
26
  generator = VecOnlineGenerator(plc, g_device=device)
 
28
  os.makedirs(fd, exist_ok=True)
29
 
30
  lvls = generator.generate(N, L)
 
31
  imgs = [lvl.to_img() for lvl in lvls]
32
  return imgs
 
33
 
 
 
 
34
 
35
 
36
  with gr.Blocks(title="NCERL Demo") as demo:
generate_and_play.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
-
3
  from src.olgen.ol_generator import VecOnlineGenerator
4
  from src.olgen.olg_game import MarioOnlineGenGame
5
  from src.olgen.olg_policy import RLGenPolicy
@@ -7,12 +7,14 @@ from src.smb.level import save_batch
7
  from src.utils.filesys import getpath
8
  from src.utils.img import make_img_sheet
9
 
 
 
10
  if __name__ == '__main__':
11
  path = 'models/example_policy'
12
  # Generate with example policy model
13
  N, L = 8, 10
14
- plc = RLGenPolicy.from_path(path)
15
- generator = VecOnlineGenerator(plc)
16
  fd, _ = os.path.split(getpath(path))
17
  os.makedirs(fd, exist_ok=True)
18
 
@@ -22,6 +24,7 @@ if __name__ == '__main__':
22
  make_img_sheet(imgs, 1, save_path=f'{path}/samples.png')
23
 
24
  # # Play with the example policy model
25
- # game = MarioOnlineGenGame(path)
26
- # game.play()
 
27
  pass
 
1
  import os
2
+ import torch
3
  from src.olgen.ol_generator import VecOnlineGenerator
4
  from src.olgen.olg_game import MarioOnlineGenGame
5
  from src.olgen.olg_policy import RLGenPolicy
 
7
  from src.utils.filesys import getpath
8
  from src.utils.img import make_img_sheet
9
 
10
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
11
+
12
  if __name__ == '__main__':
13
  path = 'models/example_policy'
14
  # Generate with example policy model
15
  N, L = 8, 10
16
+ plc = RLGenPolicy.from_path(path, device=device)
17
+ generator = VecOnlineGenerator(plc, g_device=device)
18
  fd, _ = os.path.split(getpath(path))
19
  os.makedirs(fd, exist_ok=True)
20
 
 
24
  make_img_sheet(imgs, 1, save_path=f'{path}/samples.png')
25
 
26
  # # Play with the example policy model
27
+ # Make sure a JVM is installed on your machine and that running `java` on the command line prints the Java information
28
+ game = MarioOnlineGenGame(path)
29
+ game.play()
30
  pass
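The comment added before `game.play()` asks the user to confirm that a JVM is installed and that `java` works on the command line. One way to make that check explicit before starting the game is sketched below, using only the Python standard library. The helper `java_available` is hypothetical and not part of the repository. It only tests whether a `java` executable on `PATH` runs. It does not verify what the repository's JPype-based simulator additionally needs.

```python
import shutil
import subprocess


def java_available() -> bool:
    """Return True if a `java` executable is on PATH and `java -version` succeeds."""
    exe = shutil.which("java")
    if exe is None:
        return False
    try:
        result = subprocess.run([exe, "-version"], capture_output=True, timeout=30)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


if not java_available():
    print("No working `java` found on PATH; skip the interactive play step.")
```

Guarding the play step this way lets the level generation and `samples.png` export still run on machines without Java.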
media/banner.png ADDED
models/example_policy/samples.lvls CHANGED
@@ -1,135 +1,135 @@
1
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
2
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
3
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
4
- ---------------------------------#----------------------------------------------------------------------------------------------------------------------------------------------
5
- ------------------------------oo-------------------------------o---------------------------------------------------------------o------------------------------------------------
6
- -------------------------------------------------------------------------------------------------o------------------------------------------------------------------------------
7
- --------------------------------------SSSSSSSSSS---------------------Q--------------------------------------------------------------QQQ-----------------------------QQQQ--------
8
- --------------So----------------------------------------------------------------------------------------------------------------------------------------------------------------
9
- ----------------------------------------------o----------------------o---------------------------------------K----------------------------------------------------------------o-
10
- -----------#---------------------o---------------------------------------------------------------------------2------------------------------------------------------------------
11
- ---------####--------------------------------oS------------------#---SoS-----US------------------------------U-------------------#--SSSS-----US--------tt--------##-###S-----US-
12
- ---------####----------tt-----T------------------------TT----#--TT-----------------------TT------------B---------------TT----#T-TT---------------------tt--------##-------------
13
- --------########-------Tt-----T------------------------TT----TT-TT-----K----------------TTT------------B---------------TT----TT-TT---------------------tt----TT--##-------------
14
- -------#########--gggg-Tt---kkT------k-----kk-----gggg-TT---kTT-T#-k-k-g--k-----k-ty----TTT--ggg---k-gog----kkk---gggg-TT---kTT-T--k-k-g--k-k-----g----tt---kkg--##k-k-g--k-k---
15
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
16
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
17
  ;
18
- ------------------------------------------------------------------------------------------------------------------------------------------S----------------------------S--S-----
19
- ----------------------------------------------------------------------S-------------------------------------------------------------------------------------------S-SSSSSS------
20
- -----------------------------------------------------------------------------------------------------------------------------------------%%-------------------------------------
21
- ----------------------------------S------------------------------------------------------------------------------------------------------||-------------------------------------
22
- ----------------------------------------------------------------------------Koo---------------------------------------o------------------||----------------------------o-o-----o
23
- -----------------------------------------------------------------------------------------------------------------------------------------||-------------------------------------
24
- ----------------S--Q--SSoSS--SSS--o-----------------QQoo--------------SSS----SSS%---SS-----------------------U-------SSSS-------------SSSSSSSS-------------------------SS-------
25
- S-------------------------------------------------------------------------------|--------------------------------o--------------------------------------------------------------
26
- SSS-So---------------------------------------S-S--------------------------------|-------------o--------------------------------------------------------------K------------------
27
- ----------------------------------------------S--------o------------------------|--------------------------------S-------------------------------------------2------------oo----
28
- ----------------Q---QS@Q----S@SSS-------------S--------2-----U-----------------S|------------US--------------U-------------------------------SS--------------U-------------%----
29
- ----------T--------------------------------------------tt---------------#-------|------B---------------TT----#-----------------------------------------B-------------------|----
30
- ---------TT--------------------------------K-----------tt--------------##--#-#--|------B---------------TT---TTT---------#-------------t----------------B-------------------|----
31
- ---------TT----#---k------k----------------b---g--gggg-tt---k--------####----#--|-gggggb--k-k-k---g----TT---kT#-------###-------------t------k-g---k-gog----kkk------------|----
32
- ---XXX-XXXXXXXXXXXXXXXXXXXXXXXXX------XXXXXXoXXXXXXXXXXXXXXXXXXXXoXXXXXXXX--oXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX%XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--%%%%-----|---@
33
- ---XXX-X-XXXXXXXXXXXXXXXXXXXXXXX------XXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX@--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX|-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---||------|----
34
  ;
35
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
36
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
37
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
38
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
39
- ------------------------------oo-----------------------------------------------o------------------------------------------------------------------------------------------------
40
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
41
- ----------------%S---So--------------QQ--------------------------------------------------------------Q-QQ--------S--QQSSQSSSSSSS-----QQQ----------------------------------------
42
- ----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------
43
- ----------------|--------------------------------------------K--------------------------------##-------------oo------------------------------oo---------------------------------
44
- ----------------|--------------------------------------------U-------------------o-----o-----###------g-------------------------------------------------------------------------
45
- ----------------|----------------#---SoS-----US--------------U-------------------------tt----###--#-####Q---S@S-#------------US--------------U-------------------------------o--
46
- ----------------|------TT-----K-TT---------------------B------K--------TT----#T--------tt---###------------------------------##--------K---------------TT----TT--------t--------
47
- ----------------|------TT-----U-TT-----K---------------B---------------TT----TT--------tt--####-----------------------------###------------------------TT----TT--------t--------
48
- --kk-----------g|-gggg-Tt---k-U-T#-k-k----k--------k---t----k-----gggg-TT---kTT-Tg-----tt--####----k------k-k-----ggggg----####----k-kkyk---kkk--ggg-g-TT---TTT------k-tt----#--
49
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--XXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
50
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
51
  ;
52
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
53
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
54
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
55
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
56
- ------------------------------oo-----------------------------------------------------------------------------------------------o------------------------------------------------
57
- ---------------------------------o-----------------------------------------------------------------------------------------------------------------o----------------------------
58
- ----------------%----So--------------Q---------------QQQQ------------Qo--------------QQQQS-----#--------------------------------SooS--SS-----S-----SQQQ---------------o---------
59
- ----------------|------------------------------------------------------------------------------#--------------------------------------------------------------------------------
60
- ----------------|----------------------------------------------------------------------------###-------------K------------------------------------------------o-----------------
61
- ----------------|----------------------------------------------------------------------------###-------------2---------------------------------o--------------------------------
62
- -------------oo-|----------------#--USoS-----US--#------------------------------Q-Q----QQ----###----------------------------------UQS------------------------US--------------o--
63
- ----------------|------TT-----K-TT---------------##----t---------------t---------------------###-------B---------------TT----#T----------------------------------------tt-------
64
- ---------------@|------TT-----U-TT-----K---------#---------------------t--------------------####-------B---------------TT----TT----------------------------------------tt-------
65
- ---gg----------g|-gggg-Tt---k-U-T--k-k----k------#-k--kk-----k-----k-gog----kkk---or--------####---k-gog----k-k---gggg-TT---kTT--------------------kgggg--k-k-----ggggott---kkk-
66
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX%%%%%-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
67
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-|XX--XXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
68
- ;
69
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
70
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 
71
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
72
- ----------------------------------------------------------------------------------------------------SSSS------------------------------------------------------------------------
73
- -----------------------------------------------o--------------------------------------------------------------------------------------------------------------------------------
 
 
 
 
 
 
 
74
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
75
- --------------------QS--------------------------------Q-Q------------------------------------U--SSSSSSSSo----SSS---S@S@QQ-------%---SS------------------------------QQQQ--------
76
- --------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------
77
- -----------------------------------------------------------------------------K---------------------------------------o----------|----o--------o-------------------------------o-
78
- -----------#-----------------------------------------------------------------2------------------g-------------------------------|------------------------------------g----------
79
- ---------TT#-------------------------------------T#--S#S-----US--------------U---------------U--S--SS---------S-----USSS-----US-|----S-2-----US------------------######S-----US-
80
- ---------TT----------------------------TT----TT-TT---------------------B---------------TT----#---------K------------------------|------K--K------------tt----#---##-------------
81
- --------TTT----T-----------------------TT----TT------------------------B---------------TT----TT---------------------------------|------B---------------Tt----TT--##-------------
82
- T-------TTT----T-----gg-------kg--gggg-TT---kTT-----------k--------k-gog----k-k---g----TT---kT#-------k-----k------kgggg----k---|--k-gog--k-k-----ggg--Tt----kT--##k---g--k-k---
83
- XXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
84
- X-XXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
 
 
 
 
 
85
  ;
86
- --------------------------------------------------------------------------------------------------------------------------S-----------------------------------------------------
87
- -----------------------------------------------------------------------------------------------------------------S-SSSSS-SS-----------------------------------------------------
88
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
89
- ------------------------------------------------------------------------------------------------SSSSSS--------------------------------------------------------------------------
90
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
91
- --------------------------------------------------------------------------------------------------------------------------------------------------------------oo----------------
92
- ------SSS---SS--------------------------------------SQo------------------------------Q--------------------------@SSSSSSSQSSSSSSS----SQo-----------------------SS----QQQQQSS--S@S
93
- ------------------------------------------------------------------------------------------------------------------------------------------------S--------------#----------------
94
- -------------------------------------------------------------K-----------------------o-------------------------------------------------------U--S--------------#----------------
95
- -------------------------------------------------------------------------------------------------------o---------------g---------------------------------------#----------------
96
- ---------S---@S------------------#--USSS-----US-----US-2-----U---------tt-----------USSS-----US--------tt--------SQ-SSSQQ----US--------2-----U----------------##Q-Q--QQQQS---o--
97
- -----------------------tt----TT-TT---------------------K---------------tt-------T----------------------tt----T-------------------------K--K------------------###----------------
98
- -----------------------Tt----TT-TT---------------------B---------------tt------------------------------tt----T-------------------------B-------------------#####----------------
99
- ----ggk-----k-----gggg-Tt---kkTTT--k-k----k--------k-gog--k-k-----g----tt---kkg----k---g--k-----------ttt---kkk----k---------------k-gog--k-k--------------#####----------------
100
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
101
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
102
  ;
103
- -------------------------------------------------------k------------------------------------------------------------------------------------------------------------------------
104
- -------------------------------------S------------------------------------------------------------------------------------------------------------------------------------------
105
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
106
- -------SSSSSSSSS----------------------------------------------------------------------------------------------------------------------------------------------------------------
107
- ------------------------------oo---------------------------------------------------------------o-----------o-o----o----------o--------------------------------------------------
108
- ---------------------------------------------o-------------------------------------------------------------------------------------------------o--------------------------------
109
- ----------------%S---So----------------Q--------------SSSSSSSSSS-------------------------------------SQSoSSS-SS----S--SSSS----o-----QQ---------S--------------------------------
110
- ----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------
111
- ----------------|------------------------------------------------------------U-------------------------------------------------------------------------------K------------------
112
- ----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------
113
- ----SS----------|----------------T--S##SS----US------------SSS---------------U------------------------QQQ----USS--Qo-----#------QSQ---SSSSSSS%S--------------U---------------o--
114
- ----------------|------TT-----K-TT-----------------------------------------------------TT----TT-------------------------##-------------------|---------B---------------Tt----#--
115
- ----------------|------TT-----U-TT-------------------------------------B---------------TT----TT------------------------###---#---------------|---------B---------------TT----TT-
116
- ----------------|-gggg-Tt---k-U-TT---k----b------------------------k---b----kkk---gggg-TT---kTT----U------------------####--------or-k-------|-----k-gog----kkk---ggg--Tt----kT-
117
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXX--XX---------X--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXX--XXXXXXXXXX--XXXXXXXX---XXXXX%XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
118
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXX---X---------X--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXX--X-XXX-XXXX--XXXXXXXX---XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
119
  ;
120
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
121
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
- --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
- ----------S-------------------------------------------------------------------------------------------------------------------------------------SSS-----------------------------
- -----------------------------U-o-----------------------------------------------o------------------------------------------------------------------------------------------------
- -------------------------------------------------o---------------------------------------------o--------------------------------------------------------------------------------
- ------------------------------------------SS---S%--------------------------------------------------------------------QQ-------------QQoo-----U-------QQQo------------Qo---------
- ------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------
- ---------#-----T--------------------------------|------------------------------------o-------------------------------------------------------K----------------------------------
- --------#------o--------------------------------|------------------------------------------------------o------------------------------------------------------------------------
- -------TT------T--------------------------------|----So2-----US---------------------USSS-----@S--------tt-----------USoS-----U------U--2-----U--###--------------------------U--
- -------TT----T---------TT----TT--------o--------|------K--K------------TT----TT-T----------------------tt----#---------K---------------K--------###--------------------B--------
- ------#TT---TT---------TT----TT--------#t-------|------B---------------TT----TT------------------------tt----T---------B---------------B--------###--------------------B--------
- -----##TT---TT----gg-#-Tt---kkT---gggg-TT--#y---|-gk-gog--k-k----ggggg-TT---kTT----k------k-------g----tt---kkT-T--k-kkb--k-k------k-gob--k-k---##----k-----------gggggb----k-k-
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
- XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
 
+ ----------------------------------------------------------------------------------------------------------S---------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------------------------------------------------------------------SSS-------------
+ ------------------------------oo---------------------------------------------------------------------------------------------------------------o--------------------------------
+ --------------oo----------------------------------------------------------------------------------------------------------------------------------------------------------------
+ -------SS---%%%%%----So----------S--QSSSoS---SS---------------------QQQQQSSS-SSS---------------o-------------------------------U--------------------QS--------------------------
+ -------------||-|---------------------------------------------------------------------------------------------------------------------------------------------------------------
+ -------------||-|--------------------------------------------K---------------------------------------------------------------oo-------------------------------------------------
+ -------------||-|--------------------------------------------2---------------------------------------------oo----------------------------------------------------------o--------
+ -------QSSS-%%%%|---------------Q#-----------SS-----------------------QQQSSSS@S--------------------------S%%%----------------@S---------------------USSS-----US--------tt-------
+ -------------||-|------TT-----K------------------------B------K--------------------------------------------|---------------------------tt----#-------------------------tt----T--
+ -------------||-|------TT-----U------------------------B---------------------------------------------------|-----------B---------------Tt----TT------------------------tt----T--
+ ---t---------||-|-gggg-Tt---k-U-------k-k----------k-gog----k-k----k-------------------------#-------------|------gk---b----k-k---ggg--Tt----kT----k-oog--k-k----------tt---kkk-
+ XX%XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXX-XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXX---%%%%%%%%|----XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+ X-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXX-XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXX----||||||-|----XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
  ;
+ --------------------------------------------------------------------------k-----------------------------------------------------------------------------------------------------
+ ---------------------------------------------------------S----------------K-----------------------------------------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ ------------------------------oo-------------------------o----------------------------------------------------------------------------------------------------------------------
+ ---------------------------------------------------------------------------------------------------------------o----------------------------------------------------------------
+ -------------------------------------Q----------------SSSSS--SSS----QQQQQ--------------------------------------------------------------QQ-----------SQo-------------------------
+ -------------------------------------------------------------------------------o------------------------------------------------------------------------------------------------
+ -------------------------------------o---------------------------------------------------------------o---------------------------------------oo--------------U------------------
+ -----------------------------------------------------------------T--------------------------------------------------------------------g-----------------------------------------
+ ------------------------------------USSS-----US------------SSSS--T#--------------------------o------USSS-----@S--------tt-------#-#-####-----@S--------2-----U------------------
+ -----------------------tt-----T-TT-------------------------------T#----K---------------t--------T----------------------tt-------###--------------------K--K------------Tt----#--
+ -----------------------Tt-----T-T--------------------------------##--------------------t-----TT------------------------tt-------T#-----K---------------B---------------TT----TT-
+ ---------------o--gggg-Tt---kkk-T--k-k-g--k---------------------TT#----k----k-k---kggg-b----kkg----k------k-------g----tt---kkg-T#-k------k--------k-gog--k-k-----gg---TT----TT-
+ XXXXXXXX----XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-------XXX---XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+ XXXXXXXX-X-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---X---XX----XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
  ;
+ -------K----------------------------------------------------------------------------------k---------------------------------------------------------------------------S---------
+ ----------------------------------------------------------------------------------------------------------------------------------------------------------------------S---------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------------------------------------------------------SSSSSSSSSSSSSSSS------------------------------------------------
+ -----------------------------------------------o----------------------------------------------------------------S--------------S------------------------------------------------
+ -------------ooo------------------------------------------------------------------------------------------------SSS------ooo---S----------------------------------o-------------
+ ----------%%%%%%-%-------------U----------------SSoSSSSSS----SSS--------------------QSQQQ----SSS------------------SSSSSSSSSSSSSS----QQSSSSS------------SS-SSSS---S%--------SS---
+ -----------||||--|------------------------------------------------------------------------------------------------------------------------------------------------|----TT-------
+ -----------||||--|-----------oo----------------------------------------------K-----------------------o---------------------------------------oo-------------------|----TT-------
+ ---------ooo|||--|--------------------------------------------oo-------------2------------------------------------------------------------------------------------|----TT---SS--
+ ------%%%%%%%%K--|-----------@o-----------------QSQSS---------SS--------------------Q--------USS----USSS----S@S-----------------------------S@S-------------S@S---|---###--S----
+ -------||||||-o--|-----t---------------TT----TT------------------------B------------------------T-----------------------------------------------------------------|--------S----
+ -------||||||%%%-|-----t---------------TT----TT------------------------B------------------------------------------------t-----------------------------------------|--------S----
+ ---r---||||||-|--|kk---b-----gkg--gggg-TT---kTT-----ggg------------k-gog----k-k----k--------k------k------k-----T------tt------g--kggk--------------kk---------k--|-------oy----
+ %%%%%%%%%%%||-|-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXX--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXX-SS--SXX-X--
+ -|||||||||-||-|-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXXXXXXXXXXXXXXXXX---------X-X
  ;
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ ----------------------------------------------------------------SSSS------------------------------------------------------------------------------------------------------------
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ --------------------------------------SS-----S------QQ------------------------------Q---QS---SSS----SQo-----------------------------QS----------------------------oo------------
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ -------------------------------------------------------------------------------------------------------------U------------------------------------------------------------------
+ -----------------------------------------------------------------------o---------o---------------------------------------------------------------------o------------------------
+ ------------------------------------oo-------o-------2-------U---------tt-------Q#QQQ---Q----U---------2-----U----------------------USoS-----US--------tt----o---------------SSS
+ -----------------------tt----TT------------------------K---------------tt----T-------------------------K--K------------Tt----#-------------------------tt----#------------------
+ -----------------------Tt----TT------------------------B---------------tt----TT--------------#---------B---------------TT----TT------------------------tt----TT-----------------
+ --kk--------------gggg-Tt---kkT--------------------k-k-b----k-----g----tt---kkT---ok--------k------k-gog----k-----g----TT----TT----k-kog--k-k----------tt----k------------k-----
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXX
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-XXXX
+ ;
  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ -------------------------------------------------------------K-K-------------------------------------------------SS-------------------------------------------------------------
+ ---------------------------------------------------------------o----------------------------------------------------------------------------------------------------------------
+ ------------------------------%%-----------------------------%%t------------------------------------------------SSSSSSSSSSSSSSSS------------------------------------------------
+ ------------------------------||-----------------------------||o------------------------------------------------S---SSS------S-------------------------------------------------o
+ ------------------------------||-----------------------------||%------------------------------------------------SSSSSSSSSSSoSSSS------------------------------------------------
+ ------------------------S-----|o---------------U------%%%%%--|||S--SSS-------SSS--------------------------------SSSSSSSSSSSSSSSS-----QQ---------%---SS--------------------------
+ ------------------------------||-----------------------|||---|||--------------------------------------------------------------------------------|-------------------------------
+ ------------------------------||-----o-------oo--------|||---|||-----------------------------K--------------------------------------------------|------------Ko-----------------
+ ------------------------------||-----------------------|||-o-SS|-g----------------------------------------------ooooooo--------S----------------|-------------------------------
+ ---------T--------------------||-----S-------@S--------|||%%%--|QSQSSSSSSSS-S@S%-------------U---------------U--SSSSSSS--------S-#--------------|----S-------US-----------------
+ ---------TT---TT--------------||----------K------------|||-|---|---------------|-----------------------BB----#-------------------------t--------|------K--K------------TT----TT-
+ --B------TT---TT--------------||-------B---------------|||-|---|---------------|-----------------------Tt----TT---------------------------------|------B---------------TT----TT-
+ --b------TT---TT--------------|r--gk---b----k-k--------|||-|---r---k------k----|--kk---k----kkk---ggg--Tt---kkg------ggg-------g---k--kk----k---|--k-gog--k-k-----gggg-TT---kTT-
+ XXXXXXXXXXXXXXXX---------X---X%%XXXXXXXXXXXXXXXX-------|||-|%%%%XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+ XXXXXXXXXXXXXXXX--------XX---X|XXXXXXXXXXXXXXXXX-------|||-|-||-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
  ;
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------K-----------------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ -----------------------------------------------------------------------------------------------o-------------------------------o------------------------------------------------
+ -------------------------------------------------------------------------------o---------------------------------------------------------------------------------ooo-----ooooo--
+ -------QQ-----------Q-QQQ------------SSSQS---SSS----SQQ----------------------%%%----QQQ------------------------2------SSSS-----------Qo-------------QQQQSSS--S--%%--------------
+ ------------------------------------------------------------------------------|------------------B---------------------------o----------------------------------||--------------
+ -------------------------------------------------------K----------------------|------------------------------------------------------------------------------oo-||--------------
+ ------------------------------------------------------------------------------|-----------------o-o-------------------------------------------------------------||--------------
+ -------------------#------------------Q--------------2-2-----U---------------oo-----------------------------------------------------------------------QQQ---S@S-||---2---%%%----
+ --------------------------------------t----------------K--K-------------------|--------@----------T------o-----------------------------B------------------------||--------|-----
+ --------------------------------------t------#---------B----------------------|--------tt---------TT----TT-----------------------------B------------------------||-----K--|----T
+ ------------k------k--------k--------kk------#-----k-gog----k-k---------------##---ygg-tt---k-----TT----TTT-k--------kk------------k-gog----kkk----k-k------k---|g-ggg-b--|----o
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX------%XXXX-XXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX------|-XXXXXXX-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
  ;
+ ------------------------------------------------------------------------------------------------------------------------------------------S-------------------------------------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ ------------------------------------------------------------------------------------------------SSSSSSS---------SS----S---SSSSSS------------------------------------------------
+ -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------o
+ -----------------------------------------------------------------------------------------------o--------------------------------------------------------------------------------
+ S----SQQQ-----------------------SSSSSSSS-----SSS------------------------------------------------------------------------------------------------%---SS--------------------------
+ ------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------
+ ---------------------------------------------o---------------K-----------------------o------------------------------------------------TU--------|----o-------oo-----------------
+ -----------------------o--------o----------------------------2-------------------------------------------------------S----------------T------o--|-------------------------------
+ -SSS-SSSSS--SU---------tt--------------------@SS-----------------------------U------USSS-----US--------t-----------------------------S@#--------|----SS------@S--------------o--
+ -----------------------tt----T-------------------------B---------------TT----#--T----------------------tt----------------T------S---------------|----------------------TT----#--
+ -----------------------tt----TT------------------------B---------------TT----TT-----------------------ttt----T----------TTT---------------K-----|----------------------TT----TT-
+ ---k--g-----------ggg--tt---kkT----kgggg--k-k------k-gog----kkk---ggg--TT---kTT----k------k-----------ttt---kk----------TTT---------------U-----|--k-g-g--k-k-----gggg-TT---kTg-
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
  ;
+ ---------------------------------------------------------------------------------------------------K--S---S---------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------SS----------------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------------------ooSS--------------------------------------------------------------
+ --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ -----------------------------------------------------------------o--------------------------------------------------------------------------------o------------o----------------
+ -------------------------------------Q--------------QQQQQS------------------------------oo-----U----------------------SSSS----to----------------%%%-----------------QQ---------U
+ -------------------------------------------------------------------------------------------------------o----------------------K------------------|------------------------------
+ -----------------------------------------------------------------------------K---------------o-----------------------------------------------K---|----------------------------o-
+ -----gg----------------------------------------------------------------------U-------------------------------SS------------------------------U---|------------------------------
+ Q-#####-----------------------------USSS-----US--------QQ-----------------------------------S@S-------------------------------Q------------------|---------------------------US-
+ -----------------------tt----TT-TT-------------------------------------B------K----------------------------------------------##--------B------K--|------------------------------
+ -----------------------Tt----TT-T--------------------------------------B----------------------------------------------------###--------B---------|------------T-----------------
+ ------------k-----gggg-Tt---kkT-T--k-k-g--k--------------------k---k-gog----k-k----k-------------------------r--------------##-----k-gog----k-k--|-----------oro--gggg-g----k---
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--XX----XX--XXXX--------%%%%%%%%XXXXXXXXXX--XXXXXXXXXXXXXXXXXXXX-|XXXXXXXXXXX%%%XXXXXXXXXXXXXXXX
+ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--XX---XXX--XXXX---------||||||-XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXX-|XXXX-X-XXX--|-XXXXXXXXXXXXXXXX
models/example_policy/samples.png CHANGED
plots.py DELETED
@@ -1,733 +0,0 @@
- import glob
- import json
- import os
- import re
-
- import numpy as np
- import pandas as pds
- import matplotlib
- import matplotlib.pyplot as plt
- from math import sqrt
- import torch
- from root import PRJROOT
- from sklearn.manifold import TSNE
- from itertools import product, chain
- # from src.drl.drl_uses import load_cfgs
- from src.gan.gankits import get_decoder, process_onehot
- from src.gan.gans import nz
- from src.smb.level import load_batch, hamming_dis, lvlhcat
- from src.utils.datastruct import RingQueue
- from src.utils.filesys import load_dict_json, getpath
- from src.utils.img import make_img_sheet
- from torch.distributions import Normal
-
- matplotlib.rcParams["axes.formatter.limits"] = (-5, 5)
-
-
- def print_compare_tab():
-     rand_lgp, rand_fhp, rand_divs = load_dict_json(
-         'test_data/rand_policy/performance.csv', 'lgp', 'fhp', 'diversity'
-     )
-     rand_performance = {'lgp': rand_lgp, 'fhp': rand_fhp, 'diversity': rand_divs}
-
-     def _print_line(_data, minimise=False):
-         means = _data.mean(axis=-1)
-         stds = _data.std(axis=-1)
-         max_i, min_i = np.argmax(means), np.argmin(means)
-         mean_str_content = [*map(lambda x: '%.4g' % x, _data.mean(axis=-1))]
-         std_str_content = [*map(lambda x: '$\pm$%.3g' % x, _data.std(axis=-1))]
-         if minimise:
-             mean_str_content[min_i] = r'\textbf{%s}' % mean_str_content[min_i]
-             mean_str_content[max_i] = r'\textit{%s}' % mean_str_content[max_i]
-             std_str_content[min_i] = r'\textbf{%s}' % std_str_content[min_i]
-             std_str_content[max_i] = r'\textit{%s}' % std_str_content[max_i]
-         else:
-             mean_str_content[max_i] = r'\textbf{%s}' % mean_str_content[max_i]
-             mean_str_content[min_i] = r'\textit{%s}' % mean_str_content[min_i]
-             std_str_content[max_i] = r'\textbf{%s}' % std_str_content[max_i]
-             std_str_content[min_i] = r'\textit{%s}' % std_str_content[min_i]
-         print(' &', ' & '.join(mean_str_content), r'\\')
-         print(' & &', ' & '.join(std_str_content), r'\\')
-         pass
-
-     def _print_block(_task):
-         fds = [
-             f'sac/{_task}', f'egsac/{_task}', f'asyncsac/{_task}',
-             f'pmoe/{_task}', f'dvd/{_task}', f'sunrise/{_task}',
-             f'varpm-{_task}/l0.0_m5', f'varpm-{_task}/l0.1_m5', f'varpm-{_task}/l0.2_m5',
-             f'varpm-{_task}/l0.3_m5', f'varpm-{_task}/l0.4_m5', f'varpm-{_task}/l0.5_m5'
-         ]
-         rewards, divs = [], []
-         for fd in fds:
-             rewards.append([])
-             divs.append([])
-             # print(getpath())
-             for path in glob.glob(getpath('test_data', fd, '**', 'performance.csv'), recursive=True):
-                 reward, div = load_dict_json(path, 'reward', 'diversity')
-                 rewards[-1].append(reward)
-                 divs[-1].append(div)
-         rewards = np.array(rewards)
-         divs = np.array(divs)
-
-         print(' & \\multirow{2}{*}{Reward}')
-         _print_line(rewards)
-         print(' \\cline{2-14}')
-         print(' & \\multirow{2}{*}{Diversity}')
-         _print_line(divs)
-         print(' \\cline{2-14}')
-         print(' & \\multirow{2}{*}{G-mean}')
-         gmean = np.sqrt(rewards * divs)
-         _print_line(gmean)
-
-         print(' \\cline{2-14}')
-         print(' & \\multirow{2}{*}{N-rank}')
-         r_rank = np.zeros_like(rewards.flatten())
-         r_rank[np.argsort(-rewards.flatten())] = np.linspace(1, len(r_rank), len(r_rank))
-
-         d_rank = np.zeros_like(divs.flatten())
-         d_rank[np.argsort(-divs.flatten())] = np.linspace(1, len(r_rank), len(r_rank))
-         n_rank = (r_rank.reshape([12, 5]) + d_rank.reshape([12, 5])) / (2 * 5)
-         _print_line(n_rank, True)
-
-     print(' \\multirow{8}{*}{MarioPuzzle}')
-     _print_block('fhp')
-     print(' \\midrule')
-     print(' \\multirow{8}{*}{MultiFacet}')
-     _print_block('lgp')
-     pass
-
- def print_compare_tab_nonrl():
-     # rand_lgp, rand_fhp, rand_divs = load_dict_json(
-     # 'test_data/rand_policy/performance.csv', 'lgp', 'fhp', 'diversity'
-     # )
-     # rand_performance = {'lgp': rand_lgp, 'fhp': rand_fhp, 'diversity': rand_divs}
-
-     def _print_line(_data, minimise=False):
-         means = _data.mean(axis=-1)
-         stds = _data.std(axis=-1)
-         max_i, min_i = np.argmax(means), np.argmin(means)
-         mean_str_content = [*map(lambda x: '%.4g' % x, _data.mean(axis=-1))]
-         std_str_content = [*map(lambda x: '$\pm$%.3g' % x, _data.std(axis=-1))]
-         if minimise:
-             mean_str_content[min_i] = r'\textbf{%s}' % mean_str_content[min_i]
-             mean_str_content[max_i] = r'\textit{%s}' % mean_str_content[max_i]
-             std_str_content[min_i] = r'\textbf{%s}' % std_str_content[min_i]
-             std_str_content[max_i] = r'\textit{%s}' % std_str_content[max_i]
-         else:
-             mean_str_content[max_i] = r'\textbf{%s}' % mean_str_content[max_i]
-             mean_str_content[min_i] = r'\textit{%s}' % mean_str_content[min_i]
-             std_str_content[max_i] = r'\textbf{%s}' % std_str_content[max_i]
-             std_str_content[min_i] = r'\textit{%s}' % std_str_content[min_i]
-         print(' &', ' & '.join(mean_str_content), r'\\')
-         print(' & &', ' & '.join(std_str_content), r'\\')
-         pass
-
-     def _print_block(_task):
-         fds = [
-             f'GAN-{_task}', f'DDPM-{_task}',
-             f'varpm-{_task}/l0.0_m5', f'varpm-{_task}/l0.1_m5', f'varpm-{_task}/l0.2_m5',
-             f'varpm-{_task}/l0.3_m5', f'varpm-{_task}/l0.4_m5', f'varpm-{_task}/l0.5_m5'
-         ]
-         rewards, divs = [], []
-         for fd in fds:
-             rewards.append([])
-             divs.append([])
-             # print(getpath())
-             for path in glob.glob(getpath('test_data', fd, '**', 'performance.csv'), recursive=True):
-                 reward, div = load_dict_json(path, 'reward', 'diversity')
-                 rewards[-1].append(reward)
-                 divs[-1].append(div)
-         rewards = np.array(rewards)
-         divs = np.array(divs)
-
-         print(' & \\multirow{2}{*}{Reward}')
-         _print_line(rewards)
-         print(' \\cline{2-10}')
-         print(' & \\multirow{2}{*}{Diversity}')
-         _print_line(divs)
-         print(' \\cline{2-10}')
-         # print(' & \\multirow{2}{*}{G-mean}')
-         # gmean = np.sqrt(rewards * divs)
-         # _print_line(gmean)
-         #
-         # print(' \\cline{2-10}')
-         # print(' & \\multirow{2}{*}{N-rank}')
-         # r_rank = np.zeros_like(rewards.flatten())
-         # r_rank[np.argsort(-rewards.flatten())] = np.linspace(1, len(r_rank), len(r_rank))
-         #
-         # d_rank = np.zeros_like(divs.flatten())
-         # d_rank[np.argsort(-divs.flatten())] = np.linspace(1, len(r_rank), len(r_rank))
-         # n_rank = (r_rank.reshape([8, 5]) + d_rank.reshape([8, 5])) / (2 * 5)
-         # _print_line(n_rank, True)
-
-     print(' \\multirow{4}{*}{MarioPuzzle}')
-     _print_block('fhp')
-     print(' \\midrule')
-     print(' \\multirow{4}{*}{MultiFacet}')
-     _print_block('lgp')
-     pass
-
- def plot_cmp_learning_curves(task, save_path='', title=''):
-     plt.style.use('seaborn')
-     colors = [plt.plot([0, 1], [-1000, -1000])[0].get_color() for _ in range(6)]
-     plt.cla()
-     plt.style.use('default')
-
-     # colors = ('#5D2CAB', '#005BD4', '#007CE4', '#0097DD', '#00ADC4', '#00C1A5')
-     def _get_algo_data(fd):
-         res = []
-         for i in range(1, 6):
-             path = getpath(fd, f't{i}', 'step_tests.csv')
-             try:
-                 data = pds.read_csv(path)
-                 trajectory = [
-                     [float(item['step']), float(item['r-avg']), float(item['diversity'])]
-                     for _, item in data.iterrows()
-                 ]
-                 trajectory.sort(key=lambda x: x[0])
-                 res.append(trajectory)
-                 if len(trajectory) != 26:
-                     print('Not complete (%d)/26:' % len(trajectory), path)
-             except FileNotFoundError:
-                 print(path)
-         res = np.array(res)
-         # rdsum = res[:, :, 1] + res[:, :, 2]
-         gmean = np.sqrt(res[:, :, 1] * res[:, :, 2])
-         steps = res[0, :, 0]
-         # r_avgs = np.mean(res[:, :, 1], axis=0)
-         # r_stds = np.std(res[:, :, 1], axis=0)
-         # divs = np.mean(res[:, :, 2], axis=0)
-         # div_std = np.std(res[:, :, 2], axis=0)
-         _performances = {
-             'reward': (np.mean(res[:, :, 1], axis=0), np.std(res[:, :, 1], axis=0)),
-             'diversity': (np.mean(res[:, :, 2], axis=0), np.std(res[:, :, 2], axis=0)),
-             # 'rdsum': (np.mean(rdsum, axis=0), np.std(rdsum, axis=0)),
-             'gmean': (np.mean(gmean, axis=0), np.std(gmean, axis=0)),
-         }
-         # print(_performances['gmean'])
-         return steps, _performances
-
-     def _plot_criterion(_ax, _criterion):
-         i, j, k = 0, 0, 0
-         for algo, (steps, _performances) in performances.items():
-             avgs, stds = _performances[_criterion]
-             if '\lambda' in algo:
-                 ls = '-'
-                 _c = colors[i]
-                 i += 1
-             elif algo in {'SAC', 'EGSAC', 'ASAC'}:
-                 ls = ':'
-                 _c = colors[j]
-                 j += 1
-             else:
-                 ls = '--'
-                 _c = colors[j]
-                 j += 1
-             _ax.plot(steps, avgs, color=_c, label=algo, ls=ls)
-             _ax.fill_between(steps, avgs - stds, avgs + stds, color=_c, alpha=0.15)
-             _ax.grid(False)
-             # plt.plot(steps, avgs, label=algo)
-             # plt.plot(_performances, label=algo)
-             pass
-         _ax.set_xlabel('Time step')
-
-     fig, ax = plt.subplots(1, 3, figsize=(9.6, 3.2), dpi=250, width_ratios=[1, 1, 1])
-     # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(5, 4), dpi=256)
-     # fig, ax1 = plt.subplots(1, 1, figsize=(8, 3), dpi=256)
-     # ax2 = ax1.twinx()
-     # fig = plt.plot(figsize=(4, 3), dpi=256)
-     performances = {
-         'SUNRISE': _get_algo_data(f'test_data/sunrise/{task}'),
-         '$\lambda$=0.0': _get_algo_data(f'test_data/varpm-{task}/l0.0_m5'),
-         'DvD': _get_algo_data(f'test_data/dvd/{task}'),
-         '$\lambda$=0.1': _get_algo_data(f'test_data/varpm-{task}/l0.1_m5'),
-         'PMOE': _get_algo_data(f'test_data/pmoe/{task}'),
-         '$\lambda$=0.2': _get_algo_data(f'test_data/varpm-{task}/l0.2_m5'),
-         'SAC': _get_algo_data(f'test_data/sac/{task}'),
-         '$\lambda$=0.3': _get_algo_data(f'test_data/varpm-{task}/l0.3_m5'),
-         'EGSAC': _get_algo_data(f'test_data/egsac/{task}'),
-         '$\lambda$=0.4': _get_algo_data(f'test_data/varpm-{task}/l0.4_m5'),
-         'ASAC': _get_algo_data(f'test_data/asyncsac/{task}'),
-         '$\lambda$=0.5': _get_algo_data(f'test_data/varpm-{task}/l0.5_m5'),
-     }
-     # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/SAC', '**', 'step_tests.csv'))), 'SAC')
-     # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/EGSAC', '**', 'step_tests.csv'))), 'EGSAC')
-     # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/AsyncSAC', '**', 'step_tests.csv'))), 'AsyncSAC')
-     # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/SUNRISE', '**', 'step_tests.csv'))), 'SUNRISE')
-     # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/DvD-ES', '**', 'step_tests.csv'))), 'DvD-ES')
-     # _plot_algo(*_get_algo_data(glob.glob(getpath('test_data/lbd-m-crosstest/l0.04_m5', '**', 'step_tests.csv'))), 'NCESAC')
-
-
-
-     _plot_criterion(ax[0], 'reward')
-     _plot_criterion(ax[1], 'diversity')
-     # _plot_criterion(ax[2], 'rdsum')
-     _plot_criterion(ax[2], 'gmean')
-     # ax[0].set_title(f'{title} reward')
-     ax[0].set_title(f'Cumulative Reward')
-     ax[1].set_title('Diversity Score')
-     # ax[2].set_title('Summation')
-     ax[2].set_title('G-mean')
-     # plt.title(title)
-
-     lines, labels = fig.axes[-1].get_legend_handles_labels()
-     fig.suptitle(title, fontsize=14)
-     plt.tight_layout(pad=0.5)
-     if save_path:
-         plt.savefig(getpath(save_path))
-     else:
-         plt.show()
-
-     plt.cla()
-     plt.figure(figsize=(9.6, 2.4), dpi=250)
-     plt.grid(False)
-     plt.axis('off')
-     plt.yticks([1.0])
-     plt.legend(
-         lines, labels, loc='lower center', ncol=6, edgecolor='white', fontsize=15,
-         columnspacing=0.8, borderpad=0.16, labelspacing=0.2, handlelength=2.4, handletextpad=0.3
-     )
-     plt.tight_layout(pad=0.5)
-     plt.show()
-     pass
-
- def plot_crosstest_scatters(rfunc, xrange=None, yrange=None, title=''):
-     def get_pareto():
-         all_points = list(chain(*scatter_groups.values())) + cmp_points
-         res = []
-         for p in all_points:
-             non_dominated = True
-             for q in all_points:
-                 if q[0] >= p[0] and q[1] >= p[1] and (q[0] > p[0] or q[1] > p[1]):
-                     non_dominated = False
-                     break
-             if non_dominated:
-                 res.append(p)
-         res.sort(key=lambda item:item[0])
-         return np.array(res)
-     def _hex_color(_c):
-         return
-     scatter_groups = {}
-     all_lbd = set()
-     # Initialise
-     plt.style.use('seaborn-v0_8-dark-palette')
-     # plt.figure(figsize=(4, 4), dpi=256)
-     plt.figure(figsize=(2.5, 2.5), dpi=256)
-     plt.axes().set_axisbelow(True)
-
-     # Competitors' performances
-     cmp_folders = ['asyncsac', 'egsac', 'sac', 'sunrise', 'dvd', 'pmoe']
-     cmp_names = ['ASAC', 'EGSAC', 'SAC', 'SUNRISE', 'DvD', 'PMOE']
-     cmp_labels = ['A', 'E', 'S', 'R', 'D', 'M']
-     cmp_markers = ['2', 'x', '+', 'o', '*', 'D']
-     cmp_sizes = [42, 20, 32, 16, 24, 10, 10]
-     cmp_points = []
-     for name, folder, label, mk, s in zip(cmp_names, cmp_folders, cmp_labels, cmp_markers, cmp_sizes):
-         path_fmt = getpath('test_data', folder, rfunc, '*', 'performance.csv')
-         # print(path_fmt)
-         xs, ys = [], []
-         for path in glob.glob(path_fmt, recursive=True):
-             # print(path)
-             try:
-                 x, y = load_dict_json(path, 'reward', 'diversity')
-                 xs.append(x)
-                 ys.append(y)
-                 cmp_points.append([x, y])
-                 # plt.text(x, y, label, size=7, weight='bold', va='center', ha='center', color='#202020')
-             except FileNotFoundError:
-                 print(path)
-         if label in {'A', 'E', 'S'}:
-             plt.scatter(xs, ys, marker=mk, zorder=2, s=s, label=name, color='#202020')
-         else:
-             plt.scatter(
-                 xs, ys, marker=mk, zorder=2, s=s, label=name, color=[0., 0., 0., 0.],
-                 edgecolors='#202020', linewidths=1
-             )
-     # NCESAC performances
-     for path in glob.glob(getpath('test_data', f'varpm-{rfunc}', '**', 'performance.csv'), recursive=True):
-         try:
-             x, y = load_dict_json(path, 'reward', 'diversity')
-             key = path.split('\\')[-3]
-             _, mtxt = key.split('_')
-             ltxt, _ = key.split('_')
-             lbd = float(ltxt[1:])
-             # if mtxt in {'m2', 'm3', 'm4'}:
-             # continue
-             all_lbd.add(lbd)
-             if key not in scatter_groups.keys():
-                 scatter_groups[key] = []
-             scatter_groups[key].append([x, y])
-         except Exception as e:
-             print(path)
-             print(e)
-
-     palette = plt.get_cmap('seismic')
-     color_x = [0.2, 0.33, 0.4, 0.61, 0.67, 0.79]
-     colors = {lbd: matplotlib.colors.to_hex(c) for c, lbd in zip(palette(color_x), sorted(all_lbd))}
-     colors = {0.0: '#150080', 0.1: '#066598', 0.2: '#01E499', 0.3: '#9FD40C', 0.4: '#F3B020', 0.5: '#FA0000'}
-     for lbd in sorted(all_lbd): plt.plot([-20], [-20], label=f'$\\lambda={lbd:.1f}$', lw=6, c=colors[lbd])
-     markers = {2: 'o', 3: '^', 4: 'D', 5: 'p', 6: 'h'}
-     msizes = {2: 25, 3: 25, 4: 16, 5: 28, 6: 32}
-     for key, group in scatter_groups.items():
-         ltxt, mtxt = key.split('_')
-         l = float(ltxt[1:])
-         m = int(mtxt[1:])
-         arr = np.array(group)
-         plt.scatter(
-             arr[:, 0], arr[:, 1], marker=markers[m], s=msizes[m], color=[0., 0., 0., 0.], zorder=2,
-             edgecolors=colors[l], linewidths=1
-         )
-
-     plt.xlim(xrange)
-     plt.ylim(yrange)
-     # plt.xlabel('Task Reward')
-     # plt.ylabel('Diversity')
-     # plt.legend(ncol=2)
-     # plt.legend(
-     # ncol=2, loc='lower left', columnspacing=1.2, borderpad=0.0,
-     # handlelength=1, handletextpad=0.5, framealpha=0.
-     # )
-     pareto = get_pareto()
-     plt.plot(
-         pareto[:, 0], pareto[:, 1], color='black', alpha=0.18, lw=6, zorder=3,
-         solid_joinstyle='round', solid_capstyle='round'
-     )
-     # plt.plot([88, 98, 98, 88, 88], [35, 35, 0.2, 0.2, 35], color='black', alpha=0.3, lw=1.5)
-     # plt.xticks(fontsize=16)
-     # plt.yticks(fontsize=16)
-     # plt.xticks([(1+space) * (m-mlow) + 0.5 for m in ms], [f'm={m}' for m in ms])
-     plt.title(title)
-     plt.grid()
-     plt.tight_layout(pad=0.4)
-     plt.show()
-
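Reviewer note: `get_pareto` in the removed `plot_crosstest_scatters` keeps every point that no other point beats on both reward and diversity. The sketch below shows the same non-dominated filter in isolation, assuming plain `(reward, diversity)` tuples with both objectives maximised. The name `pareto_front` is hypothetical and does not exist in the repository.

```python
def pareto_front(points):
    """Return the non-dominated (reward, diversity) points, sorted by reward.

    Hypothetical helper for illustration only; both objectives are maximised.
    """
    front = []
    for p in points:
        # q dominates p if it is at least as good on both axes and strictly better on one.
        dominated = any(
            q[0] >= p[0] and q[1] >= p[1] and (q[0] > p[0] or q[1] > p[1])
            for q in points
        )
        if not dominated:
            front.append(p)
    return sorted(front, key=lambda item: item[0])


if __name__ == "__main__":
    samples = [(1.0, 5.0), (2.0, 4.0), (1.5, 3.0), (3.0, 1.0)]
    print(pareto_front(samples))
```

Like the removed code, this is quadratic in the number of points, which seems acceptable for the few dozen runs plotted in a scatter figure.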
- def plot_varpm_heat(task, name):
-     def _get_score(m, l):
-         fd = getpath('test_data', f'varpm-{task}', f'l{l}_m{m}')
-         rewards, divs = [], []
-         for i in range(5):
-             reward, div = load_dict_json(f'{fd}/t{i+1}/performance.csv', 'reward', 'diversity')
-             rewards.append(reward)
-             divs.append(div)
-         gmean = [sqrt(r * d) for r, d in zip(rewards, divs)]
-         return np.mean(rewards), np.std(rewards), \
-             np.mean(divs), np.std(divs), \
-             np.mean(gmean), np.std(gmean)
-
-     def _plot_map(avg_map, std_map, criterion):
-         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3), dpi=256, width_ratios=(1, 1))
-         heat1 = ax1.imshow(avg_map, cmap='spring')
-         heat2 = ax2.imshow(std_map, cmap='spring')
-         ax1.set_xlim([-0.5, 5.5])
-         ax1.set_xticks([0, 1, 2, 3, 4, 5], ['$\lambda$=0.0', '$\lambda$=0.1', '$\lambda$=0.2', '$\lambda$=0.3', '$\lambda$=0.4', '$\lambda$=0.5'])
-         ax1.set_ylim([-0.5, 3.5])
-         ax1.set_yticks([0, 1, 2, 3], ['m=5', 'm=4', 'm=3', 'm=2'])
-         ax1.set_title('Average')
-         for x, y in product([0, 1, 2, 3, 4, 5], [0, 1, 2, 3]):
-             v = avg_map[y, x]
-             s = '%.4f' % v
-             if v >= 1000: s = s[:4]
-             elif v >= 1: s = s[:5]
-             else: s = s[1:6]
-             ax1.text(x, y, s, va='center', ha='center')
-         plt.colorbar(heat1, ax=ax1, shrink=0.9)
-         ax2.set_xlim([-0.5, 5.5])
-         ax2.set_xticks([0, 1, 2, 3, 4, 5], ['$\lambda$=0.0', '$\lambda$=0.1', '$\lambda$=0.2', '$\lambda$=0.3', '$\lambda$=0.4', '$\lambda$=0.5'])
-         ax2.set_ylim([-0.5, 3.5])
-         ax2.set_yticks([0, 1, 2, 3], ['m=5', 'm=4', 'm=3', 'm=2'])
-         for x, y in product([0, 1, 2, 3, 4, 5], [0, 1, 2, 3]):
-             v = std_map[y, x]
-             s = '%.4f' % v
-             if v >= 1000: s = s[:4]
-             elif v >= 1: s = s[:5]
-             else: s = s[1:6]
-             ax2.text(x, y, s, va='center', ha='center')
-         ax2.set_title('Standard Deviation')
-         plt.colorbar(heat2, ax=ax2, shrink=0.9)
-
-         fig.suptitle(f'{name}: {criterion}', fontsize=14)
-         plt.tight_layout()
-         # plt.show()
-         plt.savefig(getpath(f'results/heat/{name}-{criterion}.png'))
-
-     r_mean_map, r_std_map, d_mean_map, d_std_map, g_mean_map, g_std_map \
-         = (np.zeros([4, 6], dtype=float) for _ in range(6))
-     ms = [2, 3, 4, 5]
-     ls = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']
-     for i, j in product(range(4), range(6)):
-         r_mean, r_std, d_mean, d_std, g_mean, g_std = _get_score(ms[i], ls[j])
-         r_mean_map[i, j] = r_mean
-         r_std_map[i, j] = r_std
-         d_mean_map[i, j] = d_mean
-         d_std_map[i, j] = d_std
-         g_mean_map[i, j] = g_mean
-         g_std_map[i, j] = g_std
-
-     _plot_map(r_mean_map, r_std_map, 'Reward')
-     _plot_map(d_mean_map, d_std_map, 'Diversity')
-     _plot_map(g_mean_map, g_std_map,'G-mean')
-     # _plot_map(g_mean_map, g_std_map,'G-mean')
-
- def vis_samples():
-     # for l, m in product(['0.0', '0.1', '0.2', '0.3', '0.4', '0.5'], [2, 3, 4, 5]):
-     # for i in range(1, 6):
-     # lvls = load_batch(f'{PRJROOT}/test_data/varpm-fhp/l{l}_m{m}/t{i}/samples.lvls')
-     # imgs = [lvl.to_img(save_path=None) for lvl in lvls[:10]]
-     # make_img_sheet(imgs, 1, save_path=f'{PRJROOT}/test_data/varpm-fhp/l{l}_m{m}/t{i}/samples.png')
-     # for algo in ['sac', 'egsac', 'asyncsac', 'dvd', 'sunrise', 'pmoe']:
-     # for i in range(1, 6):
-     # lvls = load_batch(f'{PRJROOT}/test_data/{algo}/fhp/t{i}/samples.lvls')
-     # imgs = [lvl.to_img(save_path=None) for lvl in lvls[:10]]
-     # make_img_sheet(imgs, 1, save_path=f'{PRJROOT}/test_data/{algo}/fhp/t{i}/samples.png')
-     for i in range(1, 6):
-         lvls = load_batch(f'{PRJROOT}/test_data/DDPM-fhp/t{i}/samples.lvls')
-         imgs = [lvl.to_img(save_path=None) for lvl in lvls[:10]]
-         make_img_sheet(imgs, 1, save_path=f'{PRJROOT}/test_data/DDPM-fhp/t{i}/samples.png')
-         pass
-     pass
-
- def make_tsne(task, title, n=500, save_path=None):
-     if not os.path.exists(getpath('test_data', f'samples_dist-{task}_{n}.npy')):
-         samples = []
-         for algo in ['dvd', 'egsac', 'pmoe', 'sunrise', 'asyncsac', 'sac']:
-             for t in range(5):
-                 lvls = load_batch(getpath('test_data', algo, task, f't{t+1}', 'samples.lvls'))
-                 samples += lvls[:n]
-         for l in ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']:
-             for t in range(5):
-                 lvls = load_batch(getpath('test_data', f'varpm-{task}', f'l{l}_m5', f't{t+1}', 'samples.lvls'))
-                 samples += lvls[:n]
-         distmat = []
-         for a in samples:
-             dist_list = []
-             for b in samples:
-                 dist_list.append(hamming_dis(a, b))
-             distmat.append(dist_list)
-         distmat = np.array(distmat)
-         np.save(getpath('test_data', f'samples_dist-{task}_{n}.npy'), distmat)
-
-     labels = (
-         '$\lambda$=0.0', '$\lambda$=0.1', '$\lambda$=0.2', '$\lambda$=0.3', '$\lambda$=0.4',
-         '$\lambda$=0.5', 'DvD', 'EGSAC', 'PMOE', 'SUNRISE', 'ASAC', 'SAC'
-     )
-     tsne = TSNE(learning_rate='auto', n_components=2, metric='precomputed')
-     print(np.load(getpath('test_data', f'samples_dist-{task}_{n}.npy')).shape)
-     data = np.load(getpath('test_data', f'samples_dist-{task}_{n}.npy'))
-     embx = np.array(tsne.fit_transform(data))
-
-     plt.style.use('seaborn-dark-palette')
-     plt.figure(figsize=(5, 5), dpi=384)
-     colors = [plt.plot([-1000, -1100], [0, 0])[0].get_color() for _ in range(6)]
-     for i in range(6):
-         x, y = embx[i*n*5:(i+1)*n*5, 0], embx[i*n*5:(i+1)*n*5, 1]
-         plt.scatter(x, y, s=10, label=labels[i], marker='x', c=colors[i])
-     for i in range(6, 12):
-         x, y = embx[i*n*5:(i+1)*n*5, 0], embx[i*n*5:(i+1)*n*5, 1]
-         plt.scatter(x, y, s=8, linewidths=0, label=labels[i], c=colors[i-6])
-     # plt.scatter(embx[100:200, 0], embx[100:200, 1], c=colors[1], s=12, linewidths=0, label='Killer')
-     # plt.scatter(embx[200:, 0], embx[200:, 1], c=colors[2], s=12, linewidths=0, label='Collector')
-     # for i in range(4):
-     # plt.text(embx[i+100, 0], embx[i+100, 1], str(i+1))
-     # plt.text(embx[i+200, 0], embx[i+200, 1], str(i+1))
-     # pass
-     # for emb, lb, c in zip(embs, labels,colors):
-     # plt.scatter(emb[:,0], emb[:,1], c=c, label=lb, alpha=0.15, linewidths=0, s=7)
-
-     # xspan = 1.08 * max(abs(embx[:, 0].max()), abs(embx[:, 0].min()))
-     # yspan = 1.08 * max(abs(embx[:, 1].max()), abs(embx[:, 1].min()))
-
-     xrange = [1.05 * embx[:, 0].min(), 1.05 * embx[:, 0].max()]
-     yrange = [1.05 * embx[:, 1].min(), 1.25 * embx[:, 1].max()]
-
-     plt.xlim(xrange)
-     plt.ylim(yrange)
-     plt.xticks([])
-     plt.yticks([])
-     # plt.legend(ncol=6, handletextpad=0.02, labelspacing=0.05, columnspacing=0.16)
-     # plt.xticks([-xspan, -0.5 * xspan, 0, 0.5 * xspan, xspan], [''] * 5)
-     # plt.yticks([-yspan, -0.5 * yspan, 0, 0.6 * yspan, yspan], [''] * 5)
-     plt.title(title)
-     plt.legend(loc='upper center', ncol=6, fontsize=9, handlelength=.5, handletextpad=0.5, columnspacing=0.3, framealpha=0.)
-     plt.tight_layout(pad=0.2)
-
-     if save_path:
-         plt.savefig(getpath(save_path))
-     else:
-         plt.show()
-
- def _prob_fmt(p, digitals=3, threshold=0.001):
-     fmt = '%.' + str(digitals) + 'f'
-     if p < threshold:
-         return '$\\approx 0$'
-     else:
-         txt = '$%s$' % ((fmt % p)[1:])
-         if txt == '$.000$':
-             txt = '$1.00$'
-         return txt
-
- def _g_fmt(v, digitals=4):
-     fmt = '%.' + str(digitals) + 'g'
-     txt = (fmt % v)
-     lack = digitals - len(txt.replace('-', '').replace('.', ''))
-     if lack > 0 and '.' not in txt:
-         txt += '.'
-     return txt + '0' * lack
-     pass
-
- def print_selection_prob(path, h=15, runs=2):
-     s0 = 0
-     model = torch.load(getpath(f'{path}/policy.pth'), map_location='cpu')
-     model.requires_grad_(False)
-     model.to('cpu')
-     n = 11
-     # n = load_cfgs(path, 'N')
-     # print(model.m)
-
-     init_vec = np.load(getpath('analysis/initial_seg.npy'))[s0]
-     decoder = get_decoder(device='cpu')
-     obs_buffer = RingQueue(n)
-     for r in range(runs):
-         for _ in range(h): obs_buffer.push(np.zeros([nz]))
-         obs_buffer.push(init_vec)
-         level_latvecs = [init_vec]
-         probs = np.zeros([model.m, h])
-         # probs = []
-         selects = []
-         for t in range(h):
-             # probs.append([])
-             obs = torch.tensor(np.concatenate(obs_buffer.to_list(), axis=-1), dtype=torch.float).view([1, -1])
-             muss, stdss, betas = model.get_intermediate(torch.tensor(obs))
-             i = torch.multinomial(betas.squeeze(), 1).item()
-             # print(i)
-             mu, std = muss[0][i], stdss[0][i]
-             action = Normal(mu, std).rsample([1]).squeeze().numpy()
-             # print(action)
-             # print(mu)
-             # print(std)
-             # print(action.numpy())
-             obs_buffer.push(action)
-             level_latvecs.append(action)
-             # i = torch.multinomial(betas.squeeze(), 1).item()
-             # print(i)
-             probs[:, t] = betas.squeeze().numpy()
-             selects.append(i)
-             pass
-         onehots = decoder(torch.tensor(level_latvecs).view(-1, nz, 1, 1))
-         segs = process_onehot(onehots)
-         lvl = lvlhcat(segs)
-         lvl.to_img(f'figures/gen_process/run{r}-01.png')
-         txts = [[_prob_fmt(p) for p in row] for row in probs]
-         for t, i in enumerate(selects):
-             txts[i][t] = r'$\boldsymbol{%s}$' % txts[i][t][1:-1]
-         for i, txt in enumerate(txts):
-             print(f' & $\\beta_{i+1}$ &', ' & '.join(txt), r'\\')
-         print(r'\midrule')
-
-     pass
-
- def calc_selection_freqs(task, n):
-     def _count_one_init():
-         counts = np.zeros([model.m])
-         # init_vec = np.load(getpath('analysis/initial_seg.npy'))
-         obs_buffer = RingQueue(n)
-         for _ in range(runs):
-             for _ in range(h): obs_buffer.push(np.zeros([len(init_vecs), nz]))
-             obs_buffer.push(init_vecs)
-             # level_latvecs = [init_vec]
-             for _ in range(h):
-                 obs = np.concatenate(obs_buffer.to_list(), axis=-1)
-                 obs = torch.tensor(obs, device='cuda:0', dtype=torch.float)
-                 muss, stdss, betas = model.get_intermediate(obs)
-                 selects = torch.multinomial(betas.squeeze(), 1).squeeze()
-                 mus = muss[[*range(len(init_vecs))], selects, :]
-                 stds = stdss[[*range(len(init_vecs))], selects, :]
-                 actions = Normal(mus, stds).rsample().squeeze().cpu().numpy()
-                 obs_buffer.push(actions)
-                 for i in selects:
-                     counts[i] = counts[i] + 1
-         return counts
-         # onehots = decoder(torch.tensor(level_latvecs).view(-1, nz, 1, 1))
-         pass
-     pass
-     init_vecs = np.load(getpath('analysis/initial_seg.npy'))
-     freqs = [[] for _ in range(30)]
-     start_line = 0
-     for l in ('0.0', '0.1', '0.2', '0.3', '0.4', '0.5'):
-         print(r' \midrule')
-         for t, m in product(range(1, 6), (2, 3, 4, 5)):
-             path = getpath(f'test_data/varpm-{task}/l{l}_m{m}/t{t}')
-             model = torch.load(getpath(f'{path}/policy.pth'), map_location='cuda:0')
-             model.requires_grad_(False)
-             freq = np.zeros([m])
-             # n = load_cfgs(path, 'N')
-             runs, h = 100, 25
-             freq += _count_one_init()
-             freq /= (len(init_vecs) * runs * h)
-             freq = np.sort(freq)[::-1]
-             i = start_line + t - 1
-             freqs[i] += freq.tolist()
-             print(freqs[i])
-         start_line += 5
-     print(freqs)
-     with open(getpath(f'analysis/select_freqs-{task}.json'), 'w') as f:
-         json.dump(freqs, f)
-
- def print_selection_freq():
-     # task, n = 'lgp', 5
-     task, n = 'fhp', 11
-     if not os.path.exists(getpath(f'analysis/select_freqs-{task}.json')):
-         calc_selection_freqs(task, n)
-     with open(getpath(f'analysis/select_freqs-{task}.json'), 'r') as f:
-         freqs = json.load(f)
-     lbds = ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']
-     for i, row_data in enumerate(freqs):
-         if i % 5 == 0:
-             print(r' \midrule')
-             print(r' \multirow{5}{*}{$%s$}' % lbds[i//5])
-         txt = ' & '.join(map(_prob_fmt, row_data))
-         print(f' & {i%5+1} &', txt, r'\\')
-
- def print_individual_performances(task):
-     for m, l in product((2, 3, 4, 5), ('0.0', '0.1', '0.2', '0.3', '0.4', '0.5')):
-         values = []
-         if l == '0.0':
-             print(r' \midrule')
-             print(r' \multirow{6}{*}{%d}' % m)
-         for t in range(1, 6):
-             path = f'test_data/varpm-{task}/l{l}_m{m}/t{t}/performance.csv'
-             reward, diversity = load_dict_json(path, 'reward', 'diversity')
-             values.append([reward, diversity])
-         values.sort(key=lambda item: -item[0])
-         values = [*chain(*values)]
-         txts = [_g_fmt(v) for v in values]
-         print(' &', f'${l}$ & ', ' & '.join(txts), r'\\')
-     pass
-
- if __name__ == '__main__':
-     # print_selection_prob('test_data/varpm-fhp/l0.5_m5/t5')
-     # print_selection_prob('test_data/varpm-fhp/l0.1_m5/t5')
-     # print_selection_freq()
-     # print_compare_tab_nonrl()
-     # print_individual_performances('fhp')
-     # print('\n\n')
-     # print_individual_performances('lgp')
-
-     # plot_cmp_learning_curves('fhp', save_path='results/learning_curves/fhp.png', title='MarioPuzzle')
-     # plot_cmp_learning_curves('lgp', save_path='results/learning_curves/lgp.png', title='MultiFacet')
-
-     # plot_crosstest_scatters('fhp', title='MarioPuzzle')
-     # plot_crosstest_scatters('lgp', title='MultiFacet')
-     # # plot_crosstest_scatters('fhp', yrange=(0, 2500), xrange=(20, 70), title='MarioPuzzle')
-     # plot_crosstest_scatters('lgp', yrange=(0, 1500), xrange=(20, 50), title='MultiFacet')
-     # plot_crosstest_scatters('lgp', yrange=(0, 800), xrange=(44, 48), title=' ')
-
-
-     # plot_varpm_heat('fhp', 'MarioPuzzle')
-     # plot_varpm_heat('lgp', 'MultiFacet')
-
-     vis_samples()
-
-     # make_tsne('fhp', 'MarioPuzzle', n=100)
-     # make_tsne('lgp', 'MultiFacet', n=100)
-     pass
-

pyproject.toml DELETED
@@ -1,21 +0,0 @@
- [tool.poetry]
- name = "ncerl"
- version = "0.1.0"
- description = ""
- authors = ["Ziqi Wang"]
- readme = "README.md"
-
- [tool.poetry.dependencies]
- python = "^3.9"
- JPype1 = "1.3.0"
- dtw = "1.4.0"
- torch = "1.8.1"
- numpy = "^2.0.0"
- pillow = "10.0.0"
- matplotlib = "3.6.3"
- pandas = "1.3.2"
-
-
- [build-system]
- requires = ["poetry-core"]
- build-backend = "poetry.core.masonry.api"

requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
src/drl/egsac/train_egsac.py CHANGED
@@ -71,6 +71,7 @@ def train_EGSAC(args):
          return
      device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'

+     # Dynamically import the reward function
      rfunc = importlib.import_module('src.env.rfuncs').__getattribute__(f'{args.rfunc_name}')()
      with open(res_path + '/run_config.txt', 'w') as f:
          f.write(time.strftime('%Y-%m-%d %H:%M') + '\n')
@@ -83,8 +84,7 @@ def train_EGSAC(args):
          f.write('-' * 50 + '\n')
          f.write(str(rfunc))
      hist_len = rfunc.get_n()
-     # with open(f'{res_path}/hist_len.json', 'w') as f:
-     # json.dump(hist_len, f)
+
      with open(f'{res_path}/cfgs.json', 'w') as f:
          data = {'N': hist_len, 'gamma': args.gamma, 'h': args.eplen, 'rfunc': args.rfunc_name}
          json.dump(data, f)
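Reviewer note: the line commented in this hunk resolves the reward function class by name via `importlib.import_module(...).__getattribute__(...)` and then instantiates it. A minimal sketch of the same lookup using the built-in `getattr`, which is the more conventional spelling, follows. `load_by_name` is a hypothetical helper, and the stdlib `collections` module stands in for `src.env.rfuncs`, which cannot be imported outside the repository.

```python
import importlib


def load_by_name(module_name, attr_name):
    """Import module_name and return its attribute attr_name (hypothetical helper)."""
    module = importlib.import_module(module_name)
    try:
        return getattr(module, attr_name)
    except AttributeError as exc:
        # Turn a bare AttributeError into a message that names both parts.
        raise ValueError(f"{module_name} has no attribute {attr_name!r}") from exc


if __name__ == "__main__":
    cls = load_by_name("collections", "OrderedDict")
    instance = cls()  # mirrors the trailing () in the diff, which instantiates the class
    print(type(instance).__name__)
```

Wrapping the lookup this way would give a readable error when `--rfunc_name` is misspelled, though whether the project wants that behaviour is an open question.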
src/drl/sunrise/train_sunrise.py CHANGED
@@ -141,7 +141,6 @@ def get_trainer(args, obs_dim, action_dim, path, device):


      trainer = NeurIPS20SACEnsembleTrainer(
-         # env=eval_env,
          policy=L_policy,
          qf1=L_qf1,
          qf2=L_qf2,
@@ -159,7 +158,6 @@ def get_trainer(args, obs_dim, action_dim, path, device):
          **variant['trainer_kwargs']
      )
      return trainer
-     pass

  def get_algo(args, rfunc, device, path):
      algorithm = AsyncOffPolicyALgo(
@@ -178,6 +176,7 @@ def get_algo(args, rfunc, device, path):
      return algorithm

  def train_SUNRISE(args):
+     # Create the directory
      if not args.path:
          path = auto_dire('training_data', args.name)
      else:
@@ -192,6 +191,8 @@ def train_SUNRISE(args):
          print(f'Trainning at <{path}> is skipped as there has a finished trial already.')
          return
      device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'
+
+     # Import the reward function
      rfunc = importlib.import_module('src.env.rfuncs').__getattribute__(f'{args.rfunc}')()

      with open(path + '/run_configuration.txt', 'w') as f:
@@ -217,9 +218,3 @@ def train_SUNRISE(args):
      trainer = get_trainer(args, obs_dim, action_dim, path, device)
      algorithm = get_algo(args, rfunc, device, path)
      _, timecost = record_time(algorithm.train)(env, trainer, args.budget, args.inference_type, path)
-
-     pass
-
-
- if __name__ == '__main__':
-     pass
src/drl/train_async.py CHANGED
@@ -36,6 +36,9 @@ def set_common_args(parser):
36
  )
37
 
38
  def drl_train(foo):
 
 
 
39
  def __inner(args):
40
  if not args.path:
41
  path = auto_dire('training_data', args.name)
@@ -74,6 +77,7 @@ def drl_train(foo):
74
  json.dump(data, f)
75
  obs_dim, act_dim = env.histlen * nz, nz
76
 
 
77
  agent = foo(args, path, device, obs_dim, act_dim)
78
 
79
  agent.to(device)
@@ -89,6 +93,7 @@ def set_AsyncSAC_parser(parser):
89
  set_common_args(parser)
90
  parser.add_argument('--name', type=str, default='AsyncSAC', help='Name of this algorithm.')
91
 
 
92
  @drl_train
93
  def train_AsyncSAC(args, path, device, obs_dim, act_dim):
94
  actor = SoftActor(
@@ -116,12 +121,16 @@ def set_NCESAC_parser(parser):
116
  @drl_train
117
  def train_NCESAC(args, path, device, obs_dim, act_dim):
118
  me_reg, actor_nn_constructor = None, None
 
 
119
  if args.me_type == 'log':
120
  me_reg = LogWassersteinExclusion(args.lbd)
121
  elif args.me_type == 'clip':
122
  me_reg = ClipExclusion(args.lbd)
123
  elif args.me_type == 'logclip':
124
  me_reg = LogClipExclusion(args.lbd)
 
 
125
  if args.actor_net_type == 'conv':
126
  actor_nn_constructor = lambda: EsmbGaussianConv(
127
  obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m
@@ -130,7 +139,11 @@ def train_NCESAC(args, path, device, obs_dim, act_dim):
130
  actor_nn_constructor = lambda: EsmbGaussianMLP(
131
  obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m
132
  )
 
 
133
  actor = MERegMixSoftActor(actor_nn_constructor, me_reg, tar_ent=args.tar_entropy)
 
 
134
  critic = MERegSoftDoubleClipCriticQ(
135
  lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
136
  gamma=args.gamma, tau=args.tau
@@ -139,6 +152,8 @@ def train_NCESAC(args, path, device, obs_dim, act_dim):
139
  lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
140
  gamma=args.gamma, tau=args.tau
141
  )
 
 
142
  with open(f'{path}/nn_architecture.txt', 'w') as f:
143
  f.writelines([
144
  '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
 
36
  )
37
 
38
  def drl_train(foo):
39
+ """
40
+ DRL training wrapper; foo is the function being wrapped, e.g. train_AsyncSAC.
41
+ """
42
  def __inner(args):
43
  if not args.path:
44
  path = auto_dire('training_data', args.name)
 
77
  json.dump(data, f)
78
  obs_dim, act_dim = env.histlen * nz, nz
79
 
80
+ # foo decides which agent is built; the returned object is an ActCrtAgent
81
  agent = foo(args, path, device, obs_dim, act_dim)
82
 
83
  agent.to(device)
 
93
  set_common_args(parser)
94
  parser.add_argument('--name', type=str, default='AsyncSAC', help='Name of this algorithm.')
95
 
96
+ # Same SAC training, but with asynchronous evaluation added
97
  @drl_train
98
  def train_AsyncSAC(args, path, device, obs_dim, act_dim):
99
  actor = SoftActor(
 
121
  @drl_train
122
  def train_NCESAC(args, path, device, obs_dim, act_dim):
123
  me_reg, actor_nn_constructor = None, None
124
+
125
+ # Initialise the regulariser selected by args.me_type
126
  if args.me_type == 'log':
127
  me_reg = LogWassersteinExclusion(args.lbd)
128
  elif args.me_type == 'clip':
129
  me_reg = ClipExclusion(args.lbd)
130
  elif args.me_type == 'logclip':
131
  me_reg = LogClipExclusion(args.lbd)
132
+
133
+ # Initialise the network constructor selected by args.actor_net_type
134
  if args.actor_net_type == 'conv':
135
  actor_nn_constructor = lambda: EsmbGaussianConv(
136
  obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m
 
139
  actor_nn_constructor = lambda: EsmbGaussianMLP(
140
  obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m
141
  )
142
+
143
+ # Initialise the actor
144
  actor = MERegMixSoftActor(actor_nn_constructor, me_reg, tar_ent=args.tar_entropy)
145
+
146
+ # Initialise the critic
147
  critic = MERegSoftDoubleClipCriticQ(
148
  lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
149
  gamma=args.gamma, tau=args.tau
 
152
  lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
153
  gamma=args.gamma, tau=args.tau
154
  )
155
+
156
+ # Save the neural network architecture
157
  with open(f'{path}/nn_architecture.txt', 'w') as f:
158
  f.writelines([
159
  '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
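
The docstring on `drl_train` in this diff describes a wrapper whose argument `foo` is the algorithm-specific builder (for example `train_AsyncSAC`). A minimal sketch of that decorator pattern is given below. It is an illustration, not the repository's code: `drl_train_sketch` and `build_toy_agent` are hypothetical names, the returned dictionary stands in for an agent object, and the device string and dimensions are fixed placeholders where the real code derives them from the environment. The skip check mirrors the one visible in `src/drl/train_sinproc.py` in this commit.

```python
import os
import tempfile
from types import SimpleNamespace


def drl_train_sketch(foo):
    """Wrap an agent builder `foo` with shared setup (hypothetical simplification)."""
    def inner(args):
        path = args.path
        os.makedirs(path, exist_ok=True)
        # Skip the run when a finished policy file is already present.
        if os.path.exists(os.path.join(path, 'policy.pth')):
            return None
        # Placeholder values; the real code derives these from the environment.
        device, obs_dim, act_dim = 'cpu', 40, 20
        return foo(args, path, device, obs_dim, act_dim)
    return inner


@drl_train_sketch
def build_toy_agent(args, path, device, obs_dim, act_dim):
    # Hypothetical builder standing in for train_SAC / train_AsyncSAC.
    return {'name': args.name, 'device': device, 'dims': (obs_dim, act_dim)}


run_dir = tempfile.mkdtemp()
first = build_toy_agent(SimpleNamespace(path=run_dir, name='SAC'))
# Simulate a finished trial, then call again: the wrapper returns None.
open(os.path.join(run_dir, 'policy.pth'), 'w').close()
second = build_toy_agent(SimpleNamespace(path=run_dir, name='SAC'))
```

Keeping the shared setup in the wrapper means each `train_*` function only has to construct its actor and critic.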
src/drl/train_sinproc.py CHANGED
@@ -30,24 +30,30 @@ def set_common_args(parser):
30
 
31
  def drl_train(foo):
32
  def __inner(args):
 
33
  if not args.path:
34
  path = auto_dire('training_data', args.name)
35
  else:
36
  path = getpath('training_data', args.path)
37
  os.makedirs(path, exist_ok=True)
 
38
  if os.path.exists(f'{path}/policy.pth'):
39
  print(f'Trainning at <{path}> is skipped as there has a finished trial already.')
40
  return
41
  device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'
42
 
 
43
  rfunc = importlib.import_module('src.env.rfuncs').__getattribute__(f'{args.rfunc}')()
44
  env = SingleProcessOLGenEnv(rfunc, get_decoder('models/decoder.pth'), args.eplen, device=device)
 
 
45
  loggers = [
46
  AsyncCsvLogger(f'{path}/log.csv', rfunc),
47
  AsyncStdLogger(rfunc, 2000, f'{path}/log.txt' if args.redirect else '')
48
  ]
49
  if args.periodic_gen_num > 0:
50
  loggers.append(GenResLogger(path, args.periodic_gen_num, args.gen_period))
 
51
  with open(path + '/run_configuration.txt', 'w') as f:
52
  f.write(time.strftime('%Y-%m-%d %H:%M') + '\n')
53
  f.write(f'---------{args.name}---------\n')
@@ -59,14 +65,18 @@ def drl_train(foo):
59
  f.write('-' * 50 + '\n')
60
  f.write(str(rfunc))
61
  N = rfunc.get_n()
 
62
  with open(f'{path}/cfgs.json', 'w') as f:
63
  data = {'N': N, 'gamma': args.gamma, 'h': args.eplen, 'rfunc': args.rfunc}
64
  json.dump(data, f)
 
65
  obs_dim, act_dim = env.hist_len * nz, nz
66
 
 
67
  agent = foo(args, path, device, obs_dim, act_dim)
68
 
69
  agent.to(device)
 
70
  trainer = SinProcOffpolicyTrainer(
71
  ReplayMem(args.mem_size, device=device), update_per=args.update_per, batch=args.batch
72
  )
@@ -76,10 +86,12 @@ def drl_train(foo):
76
  return __inner
77
 
78
  ############### SAC ###############
 
79
  def set_SAC_parser(parser):
80
  set_common_args(parser)
81
  parser.add_argument('--name', type=str, default='SAC', help='Name of this algorithm.')
82
 
 
83
  @drl_train
84
  def train_SAC(args, path, device, obs_dim, act_dim):
85
  actor = SoftActor(
@@ -88,6 +100,7 @@ def train_SAC(args, path, device, obs_dim, act_dim):
88
  critic = SoftDoubleClipCriticQ(
89
  lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens), gamma=args.gamma, tau=args.tau
90
  )
 
91
  with open(f'{path}/nn_architecture.txt', 'w') as f:
92
  f.writelines([
93
  '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
 
30
 
31
  def drl_train(foo):
32
  def __inner(args):
33
+ # Set the save path
34
  if not args.path:
35
  path = auto_dire('training_data', args.name)
36
  else:
37
  path = getpath('training_data', args.path)
38
  os.makedirs(path, exist_ok=True)
39
+ # Check whether a fully trained model already exists
40
  if os.path.exists(f'{path}/policy.pth'):
41
  print(f'Trainning at <{path}> is skipped as there has a finished trial already.')
42
  return
43
  device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'
44
 
45
+ # Import the reward function
46
  rfunc = importlib.import_module('src.env.rfuncs').__getattribute__(f'{args.rfunc}')()
47
  env = SingleProcessOLGenEnv(rfunc, get_decoder('models/decoder.pth'), args.eplen, device=device)
48
+
49
+ # Set up the loggers
50
  loggers = [
51
  AsyncCsvLogger(f'{path}/log.csv', rfunc),
52
  AsyncStdLogger(rfunc, 2000, f'{path}/log.txt' if args.redirect else '')
53
  ]
54
  if args.periodic_gen_num > 0:
55
  loggers.append(GenResLogger(path, args.periodic_gen_num, args.gen_period))
56
+ # Save the run configuration
57
  with open(path + '/run_configuration.txt', 'w') as f:
58
  f.write(time.strftime('%Y-%m-%d %H:%M') + '\n')
59
  f.write(f'---------{args.name}---------\n')
 
65
  f.write('-' * 50 + '\n')
66
  f.write(str(rfunc))
67
  N = rfunc.get_n()
68
+ # Save the configuration file
69
  with open(f'{path}/cfgs.json', 'w') as f:
70
  data = {'N': N, 'gamma': args.gamma, 'h': args.eplen, 'rfunc': args.rfunc}
71
  json.dump(data, f)
72
+ # Set the observation and action dimensions
73
  obs_dim, act_dim = env.hist_len * nz, nz
74
 
75
+ # Create the agent
76
  agent = foo(args, path, device, obs_dim, act_dim)
77
 
78
  agent.to(device)
79
+ # Create the trainer
80
  trainer = SinProcOffpolicyTrainer(
81
  ReplayMem(args.mem_size, device=device), update_per=args.update_per, batch=args.batch
82
  )
 
86
  return __inner
87
 
88
  ############### SAC ###############
89
+ # Function that registers the SAC arguments
90
  def set_SAC_parser(parser):
91
  set_common_args(parser)
92
  parser.add_argument('--name', type=str, default='SAC', help='Name of this algorithm.')
93
 
94
+ # SAC training function
95
  @drl_train
96
  def train_SAC(args, path, device, obs_dim, act_dim):
97
  actor = SoftActor(
 
100
  critic = SoftDoubleClipCriticQ(
101
  lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens), gamma=args.gamma, tau=args.tau
102
  )
103
+ # Save the neural network architecture
104
  with open(f'{path}/nn_architecture.txt', 'w') as f:
105
  f.writelines([
106
  '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
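
`drl_train` resolves the reward function by name with `importlib.import_module(...).__getattribute__(...)` and builds a device string from `args.gpuid`. The sketch below shows the same two steps, using `getattr`, the more common spelling of the lookup. It is an illustration only: `load_by_name` and `pick_device` are hypothetical helper names, the standard-library `collections` module stands in for `src.env.rfuncs`, and CUDA availability is passed in as a plain boolean instead of calling `torch.cuda.is_available()`.

```python
import importlib


def load_by_name(module_name, attr_name):
    """Import `module_name` and instantiate the class called `attr_name`."""
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)()


def pick_device(gpuid, cuda_available):
    """Mirror the expression in the diff: CPU for negative ids or when CUDA is missing."""
    return 'cpu' if gpuid < 0 or not cuda_available else f'cuda:{gpuid}'


# `collections.OrderedDict` is only a stand-in for a reward-function class.
instance = load_by_name('collections', 'OrderedDict')
devices = [pick_device(-1, True), pick_device(0, False), pick_device(1, True)]
```

Looking classes up by name lets a command-line argument choose the reward function without an `if`/`elif` chain, at the cost of an `AttributeError` at runtime if the name is misspelled.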
src/env/environments.py CHANGED
@@ -43,7 +43,7 @@ class SingleProcessOLGenEnv(gym.Env):
43
  self.device = device
44
  self.action_space = gym.spaces.Box(-1, 1, (nz,))
45
  self.observation_space = gym.spaces.Box(-1, 1, (self.hist_len * nz,))
46
- # self.obs_queue = RingQueue(self.hist_len)
47
  self.lat_vecs = []
48
  self.simulator = MarioProxy()
49
  pass
@@ -62,7 +62,6 @@ class SingleProcessOLGenEnv(gym.Env):
62
 
63
  def __evalute(self):
64
  z = torch.tensor(np.stack(self.lat_vecs).reshape([-1, nz, 1, 1]), device=self.device, dtype=torch.float)
65
- # print(z.shape)
66
  segs = process_onehot(self.decoder(z))
67
  lvl = lvlhcat(segs)
68
  simlt_res = MarioProxy.get_seg_infos(self.simulator.simulate_complete(lvl))
@@ -123,6 +122,7 @@ class AsyncOlGenEnv:
123
  self.decoder = decoder
124
  self.decoder.to(device)
125
  self.device = device
 
126
  self.eval_pool = eval_pool
127
  self.eplen = eplen
128
  self.tid = 0
@@ -250,7 +250,6 @@ class SyncOLGenWorkerEnv(gym.Env):
250
  done = self.counter >= self.eplen
251
  if done:
252
  full_level = lvlhcat(self.segs)
253
- # full_level = self.repairer.repair(full_level)
254
  w = MarioLevel.seg_width
255
  segs = [full_level[:, s: s + w] for s in range(0, full_level.w, w)]
256
  if self.mario_proxy:
@@ -352,14 +351,12 @@ class VecOLGenEnv(SubprocVecEnv):
352
  target_remotes = self._get_target_remotes(env_ids)
353
 
354
  n_inits = 1 if self.init_one else self.hist_len
355
- # latvecs = [sample_latvec(n_inits, tensor=False) for _ in range(len(env_ids))]
356
 
357
  latvecs = [self.latvec_set[random.sample(range(len(self.latvec_set)), n_inits)] for _ in range(len(env_ids))]
358
  with torch.no_grad():
359
  segss = [[] for _ in range(len(env_ids))]
360
  for i in range(len(env_ids)):
361
  z = torch.tensor(latvecs[i]).view(-1, nz, 1, 1).to(self.device)
362
- # print(self.decoder(z).shape)
363
  segss[i] = [process_onehot(self.decoder(z))] if self.init_one else process_onehot(self.decoder(z))
364
  for remote, latvec, segs in zip(target_remotes, latvecs, segss):
365
  kwargs = {'backup_latvecs': latvec, 'backup_strsegs': [str(seg) for seg in segs]}
 
43
  self.device = device
44
  self.action_space = gym.spaces.Box(-1, 1, (nz,))
45
  self.observation_space = gym.spaces.Box(-1, 1, (self.hist_len * nz,))
46
+
47
  self.lat_vecs = []
48
  self.simulator = MarioProxy()
49
  pass
 
62
 
63
  def __evalute(self):
64
  z = torch.tensor(np.stack(self.lat_vecs).reshape([-1, nz, 1, 1]), device=self.device, dtype=torch.float)
 
65
  segs = process_onehot(self.decoder(z))
66
  lvl = lvlhcat(segs)
67
  simlt_res = MarioProxy.get_seg_infos(self.simulator.simulate_complete(lvl))
 
122
  self.decoder = decoder
123
  self.decoder.to(device)
124
  self.device = device
125
+ # The Mario simulator lives inside eval_pool
126
  self.eval_pool = eval_pool
127
  self.eplen = eplen
128
  self.tid = 0
 
250
  done = self.counter >= self.eplen
251
  if done:
252
  full_level = lvlhcat(self.segs)
 
253
  w = MarioLevel.seg_width
254
  segs = [full_level[:, s: s + w] for s in range(0, full_level.w, w)]
255
  if self.mario_proxy:
 
351
  target_remotes = self._get_target_remotes(env_ids)
352
 
353
  n_inits = 1 if self.init_one else self.hist_len
 
354
 
355
  latvecs = [self.latvec_set[random.sample(range(len(self.latvec_set)), n_inits)] for _ in range(len(env_ids))]
356
  with torch.no_grad():
357
  segss = [[] for _ in range(len(env_ids))]
358
  for i in range(len(env_ids)):
359
  z = torch.tensor(latvecs[i]).view(-1, nz, 1, 1).to(self.device)
 
360
  segss[i] = [process_onehot(self.decoder(z))] if self.init_one else process_onehot(self.decoder(z))
361
  for remote, latvec, segs in zip(target_remotes, latvecs, segss):
362
  kwargs = {'backup_latvecs': latvec, 'backup_strsegs': [str(seg) for seg in segs]}
src/env/rfunc.py CHANGED
@@ -44,6 +44,9 @@ class RewardTerm:
44
 
45
 
46
  class Playability(RewardTerm):
 
 
 
47
  def __init__(self, magnitude=1):
48
  super(Playability, self).__init__(True)
49
  self.magnitude=magnitude
@@ -57,6 +60,9 @@ class Playability(RewardTerm):
57
 
58
 
59
  class MeanDivergenceFun(RewardTerm):
 
 
 
60
  def __init__(self, goal_div, n=defaults['n'], s=8):
61
  super().__init__(False)
62
  self.l = goal_div * 0.26 / 0.6
@@ -74,7 +80,6 @@ class MeanDivergenceFun(RewardTerm):
74
  divergences = []
75
  while k * self.s <= (min(self.n, i) - 1) * MarioLevel.seg_width:
76
  cmp_seg = histroy[:, k * self.s: k * self.s + MarioLevel.seg_width]
77
- # print(i, nd, cmp_seg.shape)
78
  divergences.append(tile_pattern_js_div(seg, cmp_seg))
79
  k += 1
80
  mean_d = sum(divergences) / len(divergences)
@@ -211,9 +216,5 @@ class HistoricalDeviation(RewardTerm):
211
 
212
 
213
  if __name__ == '__main__':
214
- # print(type(ceil(0.2)))
215
- # arr = [1., 3., 2.]
216
- # arr.sort()
217
- # print(arr)
218
  rfunc = HistoricalDeviation()
219
 
 
44
 
45
 
46
  class Playability(RewardTerm):
47
+ """
48
+ Playability
49
+ """
50
  def __init__(self, magnitude=1):
51
  super(Playability, self).__init__(True)
52
  self.magnitude=magnitude
 
60
 
61
 
62
  class MeanDivergenceFun(RewardTerm):
63
+ """
64
+ Diversity
65
+ """
66
  def __init__(self, goal_div, n=defaults['n'], s=8):
67
  super().__init__(False)
68
  self.l = goal_div * 0.26 / 0.6
 
80
  divergences = []
81
  while k * self.s <= (min(self.n, i) - 1) * MarioLevel.seg_width:
82
  cmp_seg = histroy[:, k * self.s: k * self.s + MarioLevel.seg_width]
 
83
  divergences.append(tile_pattern_js_div(seg, cmp_seg))
84
  k += 1
85
  mean_d = sum(divergences) / len(divergences)
 
216
 
217
 
218
  if __name__ == '__main__':
 
 
 
 
219
  rfunc = HistoricalDeviation()
220
 
src/gan/adversarial_train.py CHANGED
@@ -87,7 +87,6 @@ def train_GAN(args):
87
  w = csv.writer(f)
88
  w.writerow(['key', 'value', ''])
89
  w.writerows(list(cfgs.items()))
90
- # pds.DataFrame.from_dict(cfgs, orient='index', columns=['value']).to_csv(f'{path_}/cfgs.csv')
91
 
92
  start_time = time.time()
93
  log_target = open(f'{res_path}/logs.csv', 'w')
@@ -120,26 +119,6 @@ def train_GAN(args):
120
  loss_G = -netD(fake).mean()
121
  loss_G.backward()
122
  optG.step()
123
- # # Evaluate
124
- # if t % args.eval_itv == (args.eval_itv - 1):
125
- # netG.eval()
126
- # netD.eval()
127
- # with torch.no_grad():
128
- # real = torch.stack(data[:min(100, len(data))])
129
- # z = sample_latvec(100, device=device, distribuion=args.noise)
130
- # fake = netG(z)
131
- # y_real = netD(real).mean().item()
132
- # y_fake = netD(fake).mean().item()
133
- # # hamming_divs, tpjs_divs = evaluate_diversity(process_onehot(fake))
134
- #
135
- # # items = (t+1, y_real, y_fake, hamming_divs, tpjs_divs, time.time() - start_time)
136
- # # log_writer.writerow(items)
137
- # print(
138
- # 'Iteration %d, y-real=%.3g, y-fake=%.3g, Hamming-divs: %.5g, TPJS-divs: %.5g, '
139
- # 'time: %.1fs' % items
140
- # )
141
- # netD.train()
142
- # netG.train()
143
  if t % args.save_itv == (args.save_itv - 1):
144
  netG.eval()
145
  netD.eval()
 
87
  w = csv.writer(f)
88
  w.writerow(['key', 'value', ''])
89
  w.writerows(list(cfgs.items()))
 
90
 
91
  start_time = time.time()
92
  log_target = open(f'{res_path}/logs.csv', 'w')
 
119
  loss_G = -netD(fake).mean()
120
  loss_G.backward()
121
  optG.step()
122
  if t % args.save_itv == (args.save_itv - 1):
123
  netG.eval()
124
  netD.eval()
src/gan/gankits.py CHANGED
@@ -3,7 +3,7 @@ from src.smb.level import MarioLevel
3
  from src.gan.gans import nz
4
  from src.utils.filesys import getpath
5
 
6
-
7
  def sample_latvec(n=1, device='cpu', distribuion='uniform'):
8
  if distribuion == 'uniform':
9
  return torch.rand(n, nz, 1, 1, device=device) * 2 - 1
@@ -12,6 +12,7 @@ def sample_latvec(n=1, device='cpu', distribuion='uniform'):
12
  else:
13
  raise TypeError(f'unknow noise distribution: {distribuion}')
14
 
 
15
  def process_onehot(raw_tensor_onehot):
16
  H, W = MarioLevel.height, MarioLevel.seg_width
17
  res = []
@@ -26,6 +27,5 @@ def get_decoder(path='models/decoder.pth', device='cpu'):
26
  decoder.requires_grad_(False)
27
  decoder.eval()
28
  return decoder
29
- pass
30
 
31
 
 
3
  from src.gan.gans import nz
4
  from src.utils.filesys import getpath
5
 
6
+ # Sample noise (latent vectors)
7
  def sample_latvec(n=1, device='cpu', distribuion='uniform'):
8
  if distribuion == 'uniform':
9
  return torch.rand(n, nz, 1, 1, device=device) * 2 - 1
 
12
  else:
13
  raise TypeError(f'unknow noise distribution: {distribuion}')
14
 
15
+ # Process the one-hot array
16
  def process_onehot(raw_tensor_onehot):
17
  H, W = MarioLevel.height, MarioLevel.seg_width
18
  res = []
 
27
  decoder.requires_grad_(False)
28
  decoder.eval()
29
  return decoder
 
30
 
31
 
src/gan/gans.py CHANGED
@@ -5,7 +5,7 @@ from src.utils.dl import SelfAttn
5
 
6
  nz = 20
7
 
8
-
9
  class SAGenerator(nn.Module):
10
  def __init__(self, base_channels=32):
11
  super(SAGenerator, self).__init__()
 
5
 
6
  nz = 20
7
 
8
+ # Self Attention GAN
9
  class SAGenerator(nn.Module):
10
  def __init__(self, base_channels=32):
11
  super(SAGenerator, self).__init__()
src/olgen/olg_policy.py CHANGED
@@ -1,6 +1,6 @@
1
  import glob
2
  import random
3
- from abc import abstractmethod, abstractstaticmethod
4
  import numpy as np
5
  import torch
6
  from src.utils.filesys import getpath
@@ -49,8 +49,6 @@ class RLGenPolicy(GenPolicy):
49
  if d < nz * self.n:
50
  obs = torch.cat([torch.zeros([b, nz * self.n - d], device=self.device), obs], dim=-1)
51
  with torch.no_grad():
52
- # mus, sigmas, betas = self.model.get_intermediate(obs)
53
- # print(mus[0].cpu().numpy(), '\n', betas[0].cpu().numpy(), '\n')
54
  model_output, _ = self.model(obs)
55
  return torch.clamp(model_output, -1, 1).squeeze().cpu().numpy()
56
 
@@ -60,43 +58,6 @@ class RLGenPolicy(GenPolicy):
60
  n = load_cfgs(path, 'N')
61
  return RLGenPolicy(model, n, device)
62
 
63
- #
64
- # class SunriseGenPolicy(GenPolicy):
65
- # def __init__(self, models, n, device='cpu'):
66
- # super(SunriseGenPolicy, self).__init__(n)
67
- # for model in models:
68
- # model.to(device)
69
- # self.models = models
70
- # self.m = len(self.models)
71
- #
72
- # self.agent = SunriseProxyAgent(models, device)
73
- #
74
- # def step(self, obs):
75
- # actions = [m(obs.unsqueeze()).squeeze().cpu().numpy() for m in self.models]
76
- # if len(obs.shape) == 1:
77
- # return random.choice(actions)
78
- # else:
79
- # actions = np.array(actions)
80
- # selections = [random.choice(range(self.m)) for _ in range(len(obs))]
81
- # selected = [actions[s, i, :] for i, s in enumerate(selections)]
82
- # return np.array(selected)
83
- # #
84
- # # def reset(self):
85
- # # # self.agent.reset()
86
- # # pass
87
- #
88
- # @staticmethod
89
- # def from_path(path, device='cpu'):
90
- # models = [
91
- # torch.load(p, map_location=device)
92
- # for p in glob.glob(getpath(path, 'policy*.pth'))
93
- # ]
94
- # n = load_cfgs(path, 'N')
95
- # return SunriseGenPolicy(models, n, device)
96
- #
97
- # # @property
98
- # # def device(self):
99
- # # return self.agent.device
100
 
101
 
102
  class EnsembleGenPolicy(GenPolicy):
@@ -114,20 +75,25 @@ class EnsembleGenPolicy(GenPolicy):
114
  actions = []
115
  with torch.no_grad():
116
  for m in self.models:
117
- a = m(o)
118
  if type(a) == tuple:
119
  a = a[0]
120
  actions.append(torch.clamp(a, -1, 1).cpu().numpy())
121
  if len(obs.shape) == 1:
122
  return random.choice(actions)
123
  else:
 
124
  actions = np.array(actions)
 
125
  selections = [random.choice(range(self.m)) for _ in range(len(obs))]
126
  selected = [actions[s, i, :] for i, s in enumerate(selections)]
127
  return np.array(selected)
128
 
129
  @staticmethod
130
  def from_path(path, device='cpu'):
 
 
 
131
  models = [
132
  torch.load(p, map_location=device)
133
  for p in glob.glob(getpath(path, 'policy*.pth'))
@@ -141,9 +107,6 @@ class RandGenPolicy(GenPolicy):
141
  super(RandGenPolicy, self).__init__(1)
142
 
143
  def step(self, obs):
144
- # if len(obs.shape) == 1:
145
- # return sample_latvec(1).squeeze().numpy()
146
- # else:
147
  n = obs.shape[0]
148
  return sample_latvec(n).squeeze().numpy()
149
 
 
1
  import glob
2
  import random
3
+ from abc import abstractmethod
4
  import numpy as np
5
  import torch
6
  from src.utils.filesys import getpath
 
49
  if d < nz * self.n:
50
  obs = torch.cat([torch.zeros([b, nz * self.n - d], device=self.device), obs], dim=-1)
51
  with torch.no_grad():
 
 
52
  model_output, _ = self.model(obs)
53
  return torch.clamp(model_output, -1, 1).squeeze().cpu().numpy()
54
 
 
58
  n = load_cfgs(path, 'N')
59
  return RLGenPolicy(model, n, device)
60
 
61
 
62
 
63
  class EnsembleGenPolicy(GenPolicy):
 
75
  actions = []
76
  with torch.no_grad():
77
  for m in self.models:
78
+ a = m(o) # action model predict
79
  if type(a) == tuple:
80
  a = a[0]
81
  actions.append(torch.clamp(a, -1, 1).cpu().numpy())
82
  if len(obs.shape) == 1:
83
  return random.choice(actions)
84
  else:
85
+ # For each observation, each of the m models outputs an action, and one of those actions is picked at random
86
  actions = np.array(actions)
87
+ # self.m is the number of models, equivalent to len(self.models)
88
  selections = [random.choice(range(self.m)) for _ in range(len(obs))]
89
  selected = [actions[s, i, :] for i, s in enumerate(selections)]
90
  return np.array(selected)
91
 
92
  @staticmethod
93
  def from_path(path, device='cpu'):
94
+ """
95
+ Load all models found in path
96
+ """
97
  models = [
98
  torch.load(p, map_location=device)
99
  for p in glob.glob(getpath(path, 'policy*.pth'))
 
107
  super(RandGenPolicy, self).__init__(1)
108
 
109
  def step(self, obs):
 
 
 
110
  n = obs.shape[0]
111
  return sample_latvec(n).squeeze().numpy()
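
The batched branch of `EnsembleGenPolicy.step` stacks one action per model and then, for each observation, keeps the action of one randomly chosen model. The sketch below reproduces that selection step with plain Python lists so it runs without NumPy or PyTorch. `select_actions` and `toy_actions` are hypothetical names introduced for illustration, and returning the chosen indices alongside the actions is an addition made here so the result can be inspected.

```python
import random


def select_actions(actions, rng=random):
    """Pick, for each observation, the action of one randomly chosen model.

    `actions[s][i]` is the action of model `s` for observation `i`.
    """
    m = len(actions)          # number of models in the ensemble
    batch = len(actions[0])   # number of observations
    selections = [rng.choice(range(m)) for _ in range(batch)]
    return [actions[s][i] for i, s in enumerate(selections)], selections


# Three toy "models", two observations, two-dimensional actions.
toy_actions = [
    [[0.1, 0.1], [0.2, 0.2]],
    [[1.1, 1.1], [1.2, 1.2]],
    [[2.1, 2.1], [2.2, 2.2]],
]
selected, chosen = select_actions(toy_actions, random.Random(0))
```

Because the choice is made independently per observation, different rows of a batch can come from different ensemble members.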
112
 
src/smb/asyncsimlt.py CHANGED
@@ -54,14 +54,11 @@ def _simlt_worker(remote, parent_remote, rfunc, resource):
54
  min_dtw = min(min_dtw, vdtw)
55
  remote.send((min_hm, min_dtw))
56
  elif cmd == 'mpd':
57
- # strpairs = data
58
  hms, dtws = [], []
59
  for strlvl1, strlvl2 in data:
60
  lvl1, lvl2 = MarioLevel(strlvl1), MarioLevel(strlvl2)
61
  hms.append(hamming_dis(lvl1, lvl2))
62
- # dtws.append(lvl_dtw(lvl1, lvl2))
63
  remote.send((hms, None))
64
- # remote.send((hms, dtws))
65
  else:
66
  raise KeyError(f'Unknown command for simulation worker: {cmd}')
67
  except EOFError:
@@ -70,6 +67,9 @@ def _simlt_worker(remote, parent_remote, rfunc, resource):
70
 
71
 
72
  class AsycSimltPool:
 
 
 
73
  def __init__(self, poolsize, queuesize=None, rfunc_name='default', verbose=True, **rsrc):
74
  self.np, self.nq = poolsize, poolsize if queuesize is None else queuesize
75
  self.waiting_queue = Queue(self.nq)
@@ -149,6 +149,7 @@ class AsycSimltPool:
149
  for work_remote, remote in zip(self.work_remotes, self.remotes):
150
  args = (work_remote, remote, rfunc, resource)
151
  # daemon=True: if the main process crashes, we should not cause things to hang
 
152
  process = ctx.Process(target=_simlt_worker, args=args, daemon=True) # pytype:disable=attribute-error
153
  process.start()
154
  self.processes.append(process)
@@ -162,12 +163,6 @@ class AsycSimltPool:
162
  time.sleep(0.01)
163
 
164
  def close(self):
165
- # finish = False
166
- # while not finish:
167
- # self.refresh()
168
- # finish = all(r for r in self.ready)
169
- # time.sleep(0.01)
170
- # self.__wait()
171
  res = self.get(True)
172
  for remote, p in zip(self.remotes, self.processes):
173
  remote.send(('close', None))
 
54
  min_dtw = min(min_dtw, vdtw)
55
  remote.send((min_hm, min_dtw))
56
  elif cmd == 'mpd':
 
57
  hms, dtws = [], []
58
  for strlvl1, strlvl2 in data:
59
  lvl1, lvl2 = MarioLevel(strlvl1), MarioLevel(strlvl2)
60
  hms.append(hamming_dis(lvl1, lvl2))
 
61
  remote.send((hms, None))
 
62
  else:
63
  raise KeyError(f'Unknown command for simulation worker: {cmd}')
64
  except EOFError:
 
67
 
68
 
69
  class AsycSimltPool:
70
+ """
71
+ Asynchronous pool for multi-process Mario simulation tasks
72
+ """
73
  def __init__(self, poolsize, queuesize=None, rfunc_name='default', verbose=True, **rsrc):
74
  self.np, self.nq = poolsize, poolsize if queuesize is None else queuesize
75
  self.waiting_queue = Queue(self.nq)
 
149
  for work_remote, remote in zip(self.work_remotes, self.remotes):
150
  args = (work_remote, remote, rfunc, resource)
151
  # daemon=True: if the main process crashes, we should not cause things to hang
152
+ # Start worker processes for asynchronous computation
153
  process = ctx.Process(target=_simlt_worker, args=args, daemon=True) # pytype:disable=attribute-error
154
  process.start()
155
  self.processes.append(process)
 
163
  time.sleep(0.01)
164
 
165
  def close(self):
166
  res = self.get(True)
167
  for remote, p in zip(self.remotes, self.processes):
168
  remote.send(('close', None))
src/smb/proxy.py CHANGED
@@ -9,10 +9,6 @@ from src.smb.level import MarioLevel, LevelRender
9
  from src.utils.filesys import getpath
10
 
11
  JVMPath = None
12
- # JVMPath = '/home/cseadmin/java/jdk1.8.0_301/jre/lib/amd64/server/libjvm.so'
13
- # JVMPath = '/home/liujl_lab/12132362/java/jdk1.8.0_301/jre/lib/amd64/server/libjvm.so'
14
- # JVMPath = '/home/liujl_lab/12132333/java/jdk1.8.0_301/jre/lib/amd64/server/libjvm.so'
15
-
16
 
17
  class MarioJavaAgents(Enum):
18
  Runner = 'agents.robinBaumgarten'
@@ -27,7 +23,6 @@ class MarioProxy:
27
  def __init__(self):
28
  if not jpype.isJVMStarted():
29
  jar_path = getpath('smb/Mario-AI-Framework.jar')
30
- # print(f"-Djava.class.path={jar_path}/Mario-AI-Framework.jar")
31
  jpype.startJVM(
32
  jpype.getDefaultJVMPath() if JVMPath is None else JVMPath,
33
  f"-Djava.class.path={jar_path}", '-Xmx2g'
@@ -137,6 +132,3 @@ class MarioProxy:
137
 
138
  if __name__ == '__main__':
139
  simulator = MarioProxy()
140
- # lvl = MarioLevel.from_file('smb/levels/lvl-1.lvl')
141
- # print(simulator.simulate_complete(lvl))
142
- # print(simulator.play_game(lvl))
 
9
  from src.utils.filesys import getpath
10
 
11
  JVMPath = None
 
 
 
 
12
 
13
  class MarioJavaAgents(Enum):
14
  Runner = 'agents.robinBaumgarten'
 
23
  def __init__(self):
24
  if not jpype.isJVMStarted():
25
  jar_path = getpath('smb/Mario-AI-Framework.jar')
 
26
  jpype.startJVM(
27
  jpype.getDefaultJVMPath() if JVMPath is None else JVMPath,
28
  f"-Djava.class.path={jar_path}", '-Xmx2g'
 
132
 
133
  if __name__ == '__main__':
134
  simulator = MarioProxy()
 
 
 
src/utils/img.py CHANGED
@@ -10,7 +10,6 @@ def make_img_sheet(imgs, ncols, x_margin=6, y_margin=6, save_path='./image.png',
10
  w_canvas = (w + x_margin) * ncols - x_margin
11
  h_canvas = (h + y_margin) * nrows - y_margin
12
  canvas = Image.new('RGBA', (w_canvas, h_canvas), (0, 0, 0, 0))
13
- # canvas.fill(margin_color)
14
  for i in range(len(imgs)):
15
  row_id, col_id = i // ncols, i % ncols
16
  canvas.paste(imgs[i], ((w + x_margin) * col_id, (h + y_margin) * row_id), imgs[i])
 
10
  w_canvas = (w + x_margin) * ncols - x_margin
11
  h_canvas = (h + y_margin) * nrows - y_margin
12
  canvas = Image.new('RGBA', (w_canvas, h_canvas), (0, 0, 0, 0))
 
13
  for i in range(len(imgs)):
14
  row_id, col_id = i // ncols, i % ncols
15
  canvas.paste(imgs[i], ((w + x_margin) * col_id, (h + y_margin) * row_id), imgs[i])
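
The canvas arithmetic in `make_img_sheet` places tiles on a grid with a margin between neighbours but none on the outer edge, which is why one margin is subtracted from each dimension. The sketch below computes the same sizes and paste offsets without Pillow. `sheet_layout` is a hypothetical helper, and the way `nrows` is derived (rounding up the tile count divided by `ncols`) is an assumption, since that line is not part of the diff.

```python
import math


def sheet_layout(n_imgs, w, h, ncols, x_margin=6, y_margin=6):
    """Return the canvas size and the top-left paste offset of every tile."""
    # Assumption: enough rows to hold all tiles (not shown in the diff).
    nrows = math.ceil(n_imgs / ncols)
    w_canvas = (w + x_margin) * ncols - x_margin
    h_canvas = (h + y_margin) * nrows - y_margin
    offsets = [
        ((w + x_margin) * (i % ncols), (h + y_margin) * (i // ncols))
        for i in range(n_imgs)
    ]
    return (w_canvas, h_canvas), offsets


size, offsets = sheet_layout(n_imgs=5, w=10, h=8, ncols=3)
```

With five 10x8 tiles in three columns, the last tile of the first row starts at x = 32 and ends at x = 42, exactly the canvas width.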
test_ddpm.py CHANGED
@@ -4,16 +4,10 @@ import torch
4
  import torch.optim as optim
5
  import torch.nn as nn
6
  import logging
7
- # from tqdm import tqdm
8
- # from torch.utils.tensorboard import SummaryWriter
9
  from src.ddpm.diffusion import Diffusion
10
  from src.ddpm.modules import UNet
11
- # from pytorch_model_summary import summary
12
- # from matplotlib import pyplot as plt
13
  from src.ddpm.dataset import create_dataloader
14
- # from utils.plot import get_img_from_level
15
  from pathlib import Path
16
- # from src.smb.level import MarioLevel
17
  import argparse
18
  import datetime
19
 
@@ -21,16 +15,12 @@ from src.gan.gankits import process_onehot, get_decoder
21
  from src.smb.level import MarioLevel, lvlhcat, save_batch
22
  from src.utils.filesys import getpath
23
  from src.utils.img import make_img_sheet
24
-
25
- # sprite_counts = np.power(np.array([102573, 9114, 1017889, 930, 3032, 7330, 2278, 2279, 5227, 5229, 5419]), 1/4)
26
  sprite_counts = np.power(np.array([
27
  74977, 15252, 572591, 5826, 1216, 7302, 237, 237, 2852, 1074, 235, 304, 48, 96, 160, 1871, 936, 186, 428, 80, 428
28
  ]), 1/4
29
  )
30
  min_count = np.min(sprite_counts)
31
 
32
- # filepath = Path(__file__).parent.resolve()
33
- # DATA_PATH = os.path.join(filepath, "levels", "ground", "unique_onehot.npz")
34
 
35
  def setup_logging(run_name, beta_schedule):
36
  model_path = os.path.join("models", beta_schedule, run_name)
@@ -39,41 +29,8 @@ def setup_logging(run_name, beta_schedule):
39
  os.makedirs(result_path, exist_ok=True)
40
  return model_path, result_path
41
 
42
- # def plot_images(epoch, sampled_images, result_path):
43
- # fig = plt.figure(figsize=(30, 15))
44
- # for i in range(len(sampled_images)):
45
- # ax1 = fig.add_subplot(4, int(len(sampled_images)/4), i+1)
46
- # ax1.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
47
- # level = sampled_images[i].argmax(dim=0).cpu().numpy()
48
- # level_img = get_img_from_level(level)
49
- # ax1.imshow(level_img)
50
- # plt.savefig(os.path.join(result_path, f"{epoch:04d}_sample.png"))
51
- # plt.close()
52
-
53
- # def plot_training_images(epoch, original_img, x_t, noise, predicted_noise, reconstructed_img, training_result_path):
54
- # fig = plt.figure(figsize=(15, 10))
55
- # for i in range(2):
56
- # ax1 = fig.add_subplot(2, 5, i*5+1)
57
- # ax1.imshow(get_img_from_level(original_img[i].cpu().numpy()))
58
- # ax1.set_title(f"Original {i}")
59
- # ax2 = fig.add_subplot(2, 5, i*5+2)
60
- # ax2.imshow(get_img_from_level(noise[i].cpu().numpy()))
61
- # ax2.set_title(f"Noise {i}")
62
- # ax3 = fig.add_subplot(2, 5, i*5+3)
63
- # ax3.imshow(get_img_from_level(x_t.argmax(dim=1).cpu().numpy()[i]))
64
- # ax3.set_title(f"x_t {i}")
65
- # ax4 = fig.add_subplot(2, 5, i*5+4)
66
- # ax4.imshow(get_img_from_level(predicted_noise[i].cpu().numpy()))
67
- # ax4.set_title(f"Predicted Noise {i}")
68
- # ax5 = fig.add_subplot(2, 5, i*5+5)
69
- # ax5.imshow(get_img_from_level(reconstructed_img.probs.argmax(dim=-1).cpu().numpy()[i]))
70
- # ax5.set_title(f"Reconstructed Image {i}")
71
- # plt.savefig(os.path.join(training_result_path, f"{epoch:04d}.png"))
72
- # plt.close()
73
-
74
  def train(args):
75
- # model_path, result_path = setup_logging(args.run_name, args.beta_schedule)
76
- # training_result_path = os.path.join(result_path, "training")
77
  path = getpath(args.res_path)
78
  os.makedirs(path, exist_ok=True)
79
 
@@ -83,24 +40,15 @@ def train(args):
83
  optimizer = optim.AdamW(model.parameters(), lr=args.lr)
84
  mse = nn.MSELoss()
85
  diffusion = Diffusion(device=device, schedule=args.beta_schedule)
86
- # logger = SummaryWriter(os.path.join("logs", args.beta_schedule, args.run_name))
87
  temperatures = torch.tensor(min_count / sprite_counts, dtype=torch.float32).to(device)
88
  l = len(dataloader)
89
 
90
- # print(summary(model, torch.zeros((64, MarioLevel.n_types, 14, 14)).to(device), diffusion.sample_timesteps(64).to(device), show_input=True))
91
-
92
- # if args.resume_from != 0:
93
- # checkpoint = torch.load(os.path.join(model_path, f'ckpt_{args.resume_from}'))
94
- # model.load_state_dict(checkpoint['model_state_dict'])
95
- # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
96
 
97
  for epoch in range(args.resume_from+1, args.resume_from+args.epochs+1):
98
  logging.info(f"Starting epoch {epoch}:")
99
  epoch_loss = {'rec_loss': 0, 'mse': 0, 'loss': 0}
100
- # pbar = tqdm(dataloader)
101
  for i, images in enumerate(dataloader):
102
  images = images.to(device)
103
- # print(images.shape)
104
  t = diffusion.sample_timesteps(images.shape[0]).to(device) # random int from 1~1000
105
  x_t, noise = diffusion.noise_images(images, t) # x_t: image with noise at t, noise: gaussian noise
106
  predicted_noise = model(x_t.float(), t.float()) # returns predicted noise eps_theta
@@ -117,38 +65,13 @@ def train(args):
117
  loss.backward()
118
  optimizer.step()
119
 
120
- # pbar.set_postfix(LOSS=loss.item())
121
- # logger.add_scalar("Rec_loss", rec_loss.item(), global_step=(epoch - 1) * l + i)
122
- # logger.add_scalar("MSE", mse_loss.item(), global_step=(epoch - 1) * l + i)
123
- # logger.add_scalar("LOSS", loss.item(), global_step=(epoch - 1) * l + i)
124
-
125
- # logger.add_scalar("Epoch_Rec_loss", epoch_loss['rec_loss']/l, global_step=epoch)
126
- # logger.add_scalar("Epoch_MSE", epoch_loss['mse']/l, global_step=epoch)
127
- # logger.add_scalar("Epoch_LOSS", epoch_loss['loss']/l, global_step=epoch)
128
  print(
129
  '\nIteration: %d' % epoch,
130
  'rec_loss: %.5g' % (epoch_loss['rec_loss']/l),
131
  'mse: %.5g' % (epoch_loss['mse']/l)
132
  )
133
 
134
- # if epoch % 20 == 19:
135
- # sampled_images = diffusion.sample(model, n=50)
136
- # imgs = [lvl.to_img() for lvl in process_onehot(sampled_images[-1])]
137
- # make_img_sheet(imgs, 10, save_path=f'{args.res_path}/sample{epoch+1}.png')
138
-
139
- # plot_images(epoch, sampled_images[-1], result_path)
140
- # plot_training_images(epoch, original_img, x_t, noise.argmax(dim=1), predicted_noise.argmax(dim=1), reconstructed_img, training_result_path)
141
-
142
  if epoch % 1000 == 0:
143
- # torch.save(model.state_dict(), os.path.join(model_path, f"ckpt_{epoch:04d}.pt"))
144
- # torch.save({
145
- # 'epoch': epoch,
146
- # 'model_state_dict': model.state_dict(),
147
- # 'optimizer_state_dict': optimizer.state_dict(),
148
- # 'Epoch_Rec_loss': epoch_loss['rec_loss']/l,
149
- # 'Epoch_MSE': epoch_loss['mse']/l,
150
- # 'Epoch_LOSS': epoch_loss['loss']/l
151
- # }, getpath(f"{args.res_path}/ddpm_{epoch}.pt"))
152
  itpath = getpath(path, f'it{epoch}')
153
  os.makedirs(itpath, exist_ok=True)
154
  model.save(getpath(path, itpath, 'ddpm.pth'))
@@ -173,11 +96,8 @@ def train(args):
173
  def launch():
174
  parser = argparse.ArgumentParser()
175
  parser.add_argument("--epochs", type=int, default=10000)
176
- # parser.add_argument("--data_path", type=str, default=DATA_PATH)
177
  parser.add_argument("--batch_size", type=int, default=256)
178
  parser.add_argument("--res_path", type=str, default='exp_data/DDPM')
179
- # parser.add_argument("--image_size", type=int, default=14)
180
- # parser.add_argument("--device", type=str, default="cuda")
181
  parser.add_argument("--gpuid", type=int, default=0)
182
  parser.add_argument("--lr", type=float, default=3e-4)
183
  parser.add_argument("--beta_schedule", type=str, default="quadratic", choices=['linear', 'quadratic', 'sigmoid'])
 
4
  import torch.optim as optim
5
  import torch.nn as nn
6
  import logging
 
 
7
  from src.ddpm.diffusion import Diffusion
8
  from src.ddpm.modules import UNet
 
 
9
  from src.ddpm.dataset import create_dataloader
 
10
  from pathlib import Path
 
11
  import argparse
12
  import datetime
13
 
 
15
  from src.smb.level import MarioLevel, lvlhcat, save_batch
16
  from src.utils.filesys import getpath
17
  from src.utils.img import make_img_sheet
 
 
18
  sprite_counts = np.power(np.array([
19
  74977, 15252, 572591, 5826, 1216, 7302, 237, 237, 2852, 1074, 235, 304, 48, 96, 160, 1871, 936, 186, 428, 80, 428
20
  ]), 1/4
21
  )
22
  min_count = np.min(sprite_counts)
23
 
 
 
24
 
25
  def setup_logging(run_name, beta_schedule):
26
  model_path = os.path.join("models", beta_schedule, run_name)
 
29
  os.makedirs(result_path, exist_ok=True)
30
  return model_path, result_path
31
 
32
+ # Test DDPM model training
33
  def train(args):
 
 
34
  path = getpath(args.res_path)
35
  os.makedirs(path, exist_ok=True)
36
 
 
40
  optimizer = optim.AdamW(model.parameters(), lr=args.lr)
41
  mse = nn.MSELoss()
42
  diffusion = Diffusion(device=device, schedule=args.beta_schedule)
 
43
  temperatures = torch.tensor(min_count / sprite_counts, dtype=torch.float32).to(device)
44
  l = len(dataloader)
45
 
 
 
 
 
 
 
46
 
47
  for epoch in range(args.resume_from+1, args.resume_from+args.epochs+1):
48
  logging.info(f"Starting epoch {epoch}:")
49
  epoch_loss = {'rec_loss': 0, 'mse': 0, 'loss': 0}
 
50
  for i, images in enumerate(dataloader):
51
  images = images.to(device)
 
52
  t = diffusion.sample_timesteps(images.shape[0]).to(device) # random int from 1~1000
53
  x_t, noise = diffusion.noise_images(images, t) # x_t: image with noise at t, noise: gaussian noise
54
  predicted_noise = model(x_t.float(), t.float()) # returns predicted noise eps_theta
 
65
  loss.backward()
66
  optimizer.step()
67
 
 
 
 
 
 
 
 
 
68
  print(
69
  '\nIteration: %d' % epoch,
70
  'rec_loss: %.5g' % (epoch_loss['rec_loss']/l),
71
  'mse: %.5g' % (epoch_loss['mse']/l)
72
  )
73
 
 
 
 
 
 
 
 
 
74
  if epoch % 1000 == 0:
 
 
 
 
 
 
 
 
 
75
  itpath = getpath(path, f'it{epoch}')
76
  os.makedirs(itpath, exist_ok=True)
77
  model.save(getpath(path, itpath, 'ddpm.pth'))
 
96
  def launch():
97
  parser = argparse.ArgumentParser()
98
  parser.add_argument("--epochs", type=int, default=10000)
 
99
  parser.add_argument("--batch_size", type=int, default=256)
100
  parser.add_argument("--res_path", type=str, default='exp_data/DDPM')
 
 
101
  parser.add_argument("--gpuid", type=int, default=0)
102
  parser.add_argument("--lr", type=float, default=3e-4)
103
  parser.add_argument("--beta_schedule", type=str, default="quadratic", choices=['linear', 'quadratic', 'sigmoid'])
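The training loop in this file samples a timestep `t`, noises the batch with `diffusion.noise_images`, and trains the model to predict that noise. The sketch below illustrates the forward noising step in NumPy, assuming the standard DDPM closed form and a linear beta schedule. The helper names `make_alpha_bar` and `noise_images`, the schedule endpoints, and the `(4, 21, 14, 14)` batch shape are illustrative assumptions; the repository's `Diffusion` class may differ in its details.

```python
import numpy as np

def make_alpha_bar(T=1000, beta_start=1e-4, beta_end=0.02):
    # Assumed linear schedule; the script also accepts 'quadratic' and 'sigmoid'.
    betas = np.linspace(beta_start, beta_end, T)
    # alpha_bar[t-1] is the cumulative product of (1 - beta) up to step t.
    return np.cumprod(1.0 - betas)

def noise_images(x0, t, alpha_bar, rng):
    # Closed form: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    a = alpha_bar[t - 1].reshape(-1, *([1] * (x0.ndim - 1)))
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps, eps

rng = np.random.default_rng(0)
alpha_bar = make_alpha_bar()
x0 = rng.standard_normal((4, 21, 14, 14))  # hypothetical batch: 21 tile channels, 14x14
t = rng.integers(1, 1001, size=4)          # timesteps drawn from 1..1000
x_t, eps = noise_images(x0, t, alpha_bar, rng)
```

In the actual loop, the model would receive `x_t` and `t` and be trained so that its output matches `eps`.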
test_gen_log.py DELETED
@@ -1,15 +0,0 @@
1
- import time
2
- import argparse
3
- from tests import evaluate_rewards, evaluate_gen_log
4
-
5
-
6
- if __name__ == '__main__':
7
- parser = argparse.ArgumentParser()
8
- parser.add_argument('--path', type=str)
9
- parser.add_argument('--parallel', type=int, default=50)
10
- parser.add_argument('--rfunc', type=str)
11
- args = parser.parse_args()
12
- start = time.time()
13
- evaluate_gen_log(args.path, args.rfunc, parallel=args.parallel)
14
- print(f'Evaluation for {args.path} finished,', '%.2f' % (time.time() - start))
15
- pass
test_gen_samples.py DELETED
@@ -1,24 +0,0 @@
1
- import json
2
- import argparse
3
- import time
4
-
5
- import numpy as np
6
- from tests import evaluate_rewards, evaluate_mpd
7
- from src.smb.level import load_batch
8
- from src.utils.filesys import getpath
9
-
10
- if __name__ == '__main__':
11
- parser = argparse.ArgumentParser()
12
- parser.add_argument('--path', type=str)
13
- parser.add_argument('--parallel', type=int, default=50)
14
- parser.add_argument('--rfunc', type=str)
15
- args = parser.parse_args()
16
- start = time.time()
17
- lvls = load_batch(getpath(args.path, 'samples.lvls'))
18
- rewards = [sum(item) for item in evaluate_rewards(lvls, args.rfunc, parallel=args.parallel)]
19
- diversity = evaluate_mpd(lvls)
20
- with open(getpath(args.path, 'performance.csv'), 'w') as f:
21
- json.dump({'reward': np.mean(rewards), 'diversity': diversity}, f)
22
- print(f'Evaluation for {args.path} finished,', '%.2f' % (time.time() - start))
23
-
24
- pass
tests.py DELETED
@@ -1,140 +0,0 @@
1
- import csv
2
- import time
3
-
4
- import torch
5
-
6
- from plots import print_compare_tab_nonrl
7
- from src.gan.gankits import *
8
- from src.smb.level import *
9
- from itertools import combinations, chain
10
- from src.utils.filesys import getpath
11
- from src.smb.asyncsimlt import AsycSimltPool
12
-
13
-
14
- def evaluate_rewards(lvls, rfunc='default', dest_path='', parallel=1, eval_pool=None):
15
- internal_pool = eval_pool is None
16
- if internal_pool:
17
- eval_pool = AsycSimltPool(parallel, rfunc_name=rfunc, verbose=False, test=True)
18
- res = []
19
- for lvl in lvls:
20
- eval_pool.put('evaluate', (0, str(lvl)))
21
- buffer = eval_pool.get()
22
- for _, item in buffer:
23
- res.append([sum(r) for r in zip(*item.values())])
24
- if internal_pool:
25
- buffer = eval_pool.close()
26
- else:
27
- buffer = eval_pool.get(True)
28
- for _, item in buffer:
29
- res.append([sum(r) for r in zip(*item.values())])
30
- if len(dest_path):
31
- np.save(dest_path, res)
32
- return res
33
-
34
- def evaluate_mpd(lvls, parallel=2):
35
- task_datas = [[] for _ in range(parallel)]
36
- for i, (A, B) in enumerate(combinations(lvls, 2)):
37
- # lvlA, lvlB = lvls[i * 2], lvls[i * 2 + 1]
38
- task_datas[i % parallel].append((str(A), str(B)))
39
-
40
- hms, dtws = [], []
41
- eval_pool = AsycSimltPool(parallel, verbose=False)
42
- for task_data in task_datas:
43
- eval_pool.put('mpd', task_data)
44
- res = eval_pool.get(wait=True)
45
- for task_hms, _ in res:
46
- hms += task_hms
47
- return np.mean(hms)
48
-
49
- def evaluate_gen_log(path, rfunc_name, parallel=5):
50
- f = open(getpath(f'{path}/step_tests.csv'), 'w', newline='')
51
- wrtr = csv.writer(f)
52
- cols = ['step', 'r-avg', 'r-std', 'diversity']
53
- wrtr.writerow(cols)
54
- start_time = time.time()
55
- for lvls, name in traverse_batched_level_files(f'{path}/gen_log'):
56
- step = name[4:]
57
- rewards = [sum(item) for item in evaluate_rewards(lvls, rfunc_name, parallel=parallel)]
58
- r_avg, r_std = np.mean(rewards), np.std(rewards)
59
- mpd = evaluate_mpd(lvls, parallel=parallel)
60
- line = [step, r_avg, r_std, mpd]
61
- wrtr.writerow(line)
62
- f.flush()
63
- print(
64
- f'{path}: step{step} evaluated in {time.time()-start_time:.1f}s -- '
65
- + '; '.join(f'{k}: {v}' for k, v in zip(cols, line))
66
- )
67
- f.close()
68
- pass
69
-
70
-
71
-
72
-
73
- if __name__ == '__main__':
74
- # print_compare_tab_nonrl()
75
-
76
- arr = [[1, 2], [1, 2]]
77
- arr = [*chain(*arr)]
78
- print(arr)
79
- for i in range(5):
80
- path = f'training_data/GAN{i}'
81
- lvls = []
82
- init_lateves = torch.tensor(np.load(getpath('analysis/initial_seg.npy')), device='cuda:0')
83
- decoder = get_decoder(device='cuda:0')
84
- init_seg_onehots = decoder(init_lateves.view(*init_lateves.shape, 1, 1))
85
- gan = get_decoder(f'{path}/decoder.pth', device='cuda:0')
86
- for init_seg_onehot in init_seg_onehots:
87
- seg_onehots = gan(sample_latvec(25, device='cuda:0'))
88
- a = init_seg_onehot.view(1, *init_seg_onehot.shape)
89
- b = seg_onehots
90
- # print(a.shape, b.shape)
91
- segs = process_onehot(torch.cat([a, b], dim=0))
92
- level = lvlhcat(segs)
93
- lvls.append(level)
94
- save_batch(lvls, getpath(path, 'samples.lvls'))
95
- lvls = load_batch(f'{path}/samples.lvls')[:15]
96
- imgs = [lvl.to_img() for lvl in lvls]
97
- make_img_sheet(imgs, 1, save_path=f'generation_results/GAN/trial{i+1}/sample_lvls.png')
98
-
99
- ts = torch.tensor([
100
- [[0, 0], [0, 1], [0, 2]],
101
- [[1, 0], [1, 1], [1, 2]],
102
- ])
103
- print(ts.shape)
104
- print(ts[[*range(2)], [1, 2], :])
105
- task = 'fhp'
106
- parallel = 50
107
- samples = []
108
- for algo in ['dvd', 'egsac', 'pmoe', 'sunrise', 'asyncsac', 'sac']:
109
- for t in range(5):
110
- lvls = load_batch(getpath('test_data', algo, task, f't{t + 1}', 'samples.lvls'))
111
- samples += lvls
112
- for l in ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']:
113
- for t in range(5):
114
- lvls = load_batch(getpath('test_data', f'varpm-{task}', f'l{l}_m5', f't{t + 1}', 'samples.lvls'))
115
- samples += lvls
116
-
117
- # task_datas = [[] for _ in range(parallel)]
118
- # for i, (A, B) in enumerate(combinations(samples, 2)):
119
- # lvlA, lvlB = lvls[i * 2], lvls[i * 2 + 1]
120
- # task_datas[i % parallel].append((str(A), str(B)))
121
-
122
- distmat = []
123
- eval_pool = AsycSimltPool(parallel, verbose=False)
124
- for A in samples:
125
- eval_pool.put('mpd', [(str(A), str(B)) for B in samples])
126
- res = eval_pool.get()
127
- for task_hms, _ in res:
128
- hms += task_hms
129
- np.save(getpath('test_data', f'samples_dists-{task}.npy'), hms)
130
-
131
- start = time.time()
132
- samples = load_batch(getpath('test_data/varpm-fhp/l0.0_m2/t1/samples.lvls'))
133
- distmat = []
134
- for a in samples:
135
- dist_list = []
136
- for b in samples:
137
- dist_list.append(hamming_dis(a, b))
138
- distmat.append(dist_list)
139
- print(time.time() - start)
140
- pass
train.py CHANGED
@@ -46,4 +46,6 @@ if __name__ == '__main__':
46
  args = parser.parse_args()
47
 
48
  entry = args.entry
 
 
49
  entry(args)
 
46
  args = parser.parse_args()
47
 
48
  entry = args.entry
49
+
50
+ # entry is the training entry point for each model; the concrete function is defined in each subparser
51
  entry(args)
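The comment added above describes a dispatch pattern in which each subparser registers its own training function as `entry`. A minimal sketch of that pattern with the standard `argparse` module follows. The function names `train_sac` and `train_ncesac`, the `build_parser` helper, and the default values are hypothetical stand-ins, not the repository's actual definitions.

```python
import argparse

def train_sac(args):
    # Hypothetical stand-in for a real training routine.
    return f"sac lr={args.lr}"

def train_ncesac(args):
    # Hypothetical stand-in for a real training routine.
    return f"ncesac lbd={args.lbd} m={args.m}"

def build_parser():
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="algo", required=True)

    sac = subparsers.add_parser("sac")
    sac.add_argument("--lr", type=float, default=3e-4)
    sac.set_defaults(entry=train_sac)      # bind this subcommand to its function

    nce = subparsers.add_parser("ncesac")
    nce.add_argument("--lbd", type=float, default=0.5)
    nce.add_argument("--m", type=int, default=5)
    nce.set_defaults(entry=train_ncesac)
    return parser

# Parse an explicit argument list so the sketch runs without a command line.
args = build_parser().parse_args(["ncesac", "--lbd", "0.3", "--m", "5"])
result = args.entry(args)  # dispatches to train_ncesac
```

Because `set_defaults` attaches the function to the parsed namespace, the main script only needs the single `entry(args)` call regardless of which subcommand was chosen.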