santit96 committed on
Commit
c412087
•
1 Parent(s): c10a05f

Fix code style with black and isort

Browse files
a3c/eval.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import torch
3
 
4
  from .net import GreedyNet
@@ -14,9 +15,7 @@ def evaluate_checkpoints(dir, env):
14
  wins, guesses = evaluate(env, pretrained_model_path)
15
  results[checkpoint] = wins, guesses
16
  return dict(
17
- sorted(results.items(), key=lambda x: (
18
- x[1][0], -x[1][1]), reverse=True
19
- )
20
  )
21
 
22
 
@@ -39,4 +38,4 @@ def evaluate(env, pretrained_model_path):
39
  took {n_win_guesses/n_wins} guesses per win, "
40
  f"{n_guesses / N} including losses."
41
  )
42
- return n_wins/N*100, n_win_guesses/n_wins
 
1
  import os
2
+
3
  import torch
4
 
5
  from .net import GreedyNet
 
15
  wins, guesses = evaluate(env, pretrained_model_path)
16
  results[checkpoint] = wins, guesses
17
  return dict(
18
+ sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True)
 
 
19
  )
20
 
21
 
 
38
  took {n_win_guesses/n_wins} guesses per win, "
39
  f"{n_guesses / N} including losses."
40
  )
41
+ return n_wins / N * 100, n_win_guesses / n_wins
a3c/net.py CHANGED
@@ -1,7 +1,7 @@
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
- import numpy as np
5
 
6
 
7
  class Net(nn.Module):
@@ -23,15 +23,15 @@ class Net(nn.Module):
23
  word_array = np.zeros((word_width, len(word_list)))
24
  for i, word in enumerate(word_list):
25
  for j, c in enumerate(word):
26
- word_array[j*26 + (ord(c) - ord('A')), i] = 1
27
  self.words = torch.Tensor(word_array)
28
 
29
  def forward(self, x):
30
  values = self.v1(x.float())
31
  logits = torch.log_softmax(
32
- torch.tensordot(self.actor_head(values), self.words,
33
- dims=((1,), (0,))),
34
- dim=-1)
35
  values = self.v4(values)
36
  return logits, values
37
 
 
1
+ import numpy as np
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
 
5
 
6
 
7
  class Net(nn.Module):
 
23
  word_array = np.zeros((word_width, len(word_list)))
24
  for i, word in enumerate(word_list):
25
  for j, c in enumerate(word):
26
+ word_array[j * 26 + (ord(c) - ord("A")), i] = 1
27
  self.words = torch.Tensor(word_array)
28
 
29
  def forward(self, x):
30
  values = self.v1(x.float())
31
  logits = torch.log_softmax(
32
+ torch.tensordot(self.actor_head(values), self.words, dims=((1,), (0,))),
33
+ dim=-1,
34
+ )
35
  values = self.v4(values)
36
  return logits, values
37
 
a3c/play.py CHANGED
@@ -1,15 +1,18 @@
1
  import os
 
2
  import torch
3
  from dotenv import load_dotenv
 
4
  from wordle_env.state import update_from_mask
 
5
  from .net import GreedyNet
6
  from .utils import v_wrap
7
 
8
 
9
  def get_play_model_path():
10
  load_dotenv()
11
- model_name = os.getenv('RS_WORDLE_MODEL_NAME')
12
- model_checkpoint_dir = os.path.join('checkpoints', 'best_models')
13
  return os.path.join(model_checkpoint_dir, model_name)
14
 
15
 
@@ -28,12 +31,7 @@ def get_initial_state(env):
28
  return state
29
 
30
 
31
- def suggest(
32
- env,
33
- words,
34
- states,
35
- pretrained_model_path
36
- ) -> str:
37
  """
38
  Given a list of words and masks, return the next suggested word
39
 
 
1
  import os
2
+
3
  import torch
4
  from dotenv import load_dotenv
5
+
6
  from wordle_env.state import update_from_mask
7
+
8
  from .net import GreedyNet
9
  from .utils import v_wrap
10
 
11
 
12
  def get_play_model_path():
13
  load_dotenv()
14
+ model_name = os.getenv("RS_WORDLE_MODEL_NAME")
15
+ model_checkpoint_dir = os.path.join("checkpoints", "best_models")
16
  return os.path.join(model_checkpoint_dir, model_name)
17
 
18
 
 
31
  return state
32
 
33
 
34
+ def suggest(env, words, states, pretrained_model_path) -> str:
 
 
 
 
 
35
  """
36
  Given a list of words and masks, return the next suggested word
37
 
a3c/shared_adam.py CHANGED
@@ -6,20 +6,18 @@ import torch
6
 
7
 
8
  class SharedAdam(torch.optim.Adam):
9
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
10
- weight_decay=0):
11
  super(SharedAdam, self).__init__(
12
- params, lr=lr,
13
- betas=betas, eps=eps, weight_decay=weight_decay
14
  )
15
  # State initialization
16
  for group in self.param_groups:
17
- for p in group['params']:
18
  state = self.state[p]
19
- state['step'] = 0
20
- state['exp_avg'] = torch.zeros_like(p.data)
21
- state['exp_avg_sq'] = torch.zeros_like(p.data)
22
 
23
  # share in memory
24
- state['exp_avg'].share_memory_()
25
- state['exp_avg_sq'].share_memory_()
 
6
 
7
 
8
  class SharedAdam(torch.optim.Adam):
9
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0):
 
10
  super(SharedAdam, self).__init__(
11
+ params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
 
12
  )
13
  # State initialization
14
  for group in self.param_groups:
15
+ for p in group["params"]:
16
  state = self.state[p]
17
+ state["step"] = 0
18
+ state["exp_avg"] = torch.zeros_like(p.data)
19
+ state["exp_avg_sq"] = torch.zeros_like(p.data)
20
 
21
  # share in memory
22
+ state["exp_avg"].share_memory_()
23
+ state["exp_avg_sq"].share_memory_()
a3c/train.py CHANGED
@@ -1,10 +1,12 @@
1
  import os
2
- import numpy as np
3
  import random
 
 
4
  import torch
5
  import torch.multiprocessing as mp
6
- from .shared_adam import SharedAdam
7
  from .net import Net
 
8
  from .worker import Worker
9
 
10
 
@@ -25,12 +27,12 @@ def train(
25
  env,
26
  max_ep,
27
  model_checkpoint_dir,
28
- gamma=0.,
29
  seed=100,
30
  pretrained_model_path=None,
31
  save=False,
32
  min_reward=9.9,
33
- every_n_save=100
34
  ):
35
  os.environ["OMP_NUM_THREADS"] = "1"
36
  if not os.path.exists(model_checkpoint_dir):
@@ -45,18 +47,40 @@ def train(
45
  if pretrained_model_path:
46
  gnet.load_state_dict(torch.load(pretrained_model_path))
47
  gnet.share_memory() # share the global parameters in multiprocessing
48
- opt = SharedAdam(gnet.parameters(), lr=1e-4,
49
- betas=(0.92, 0.999)) # global optimizer
50
- global_ep, global_ep_r, res_queue, win_ep = mp.Value(
51
- 'i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
 
 
 
 
 
52
 
53
  # parallel training
54
  workers = [
55
  Worker(
56
- max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env,
57
- n_s, n_a, words_list, word_width, win_ep, model_checkpoint_dir,
58
- gamma, pretrained_model_path, save, min_reward, every_n_save
59
- ) for i in range(mp.cpu_count())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  ]
61
  [w.start() for w in workers]
62
  res = [] # record episode reward to plot
@@ -68,6 +92,8 @@ def train(
68
  break
69
  [w.join() for w in workers]
70
  if save:
71
- torch.save(gnet.state_dict(), os.path.join(
72
- model_checkpoint_dir, f'model_{env.unwrapped.spec.id}.pth'))
 
 
73
  return global_ep, win_ep, gnet, res
 
1
  import os
 
2
  import random
3
+
4
+ import numpy as np
5
  import torch
6
  import torch.multiprocessing as mp
7
+
8
  from .net import Net
9
+ from .shared_adam import SharedAdam
10
  from .worker import Worker
11
 
12
 
 
27
  env,
28
  max_ep,
29
  model_checkpoint_dir,
30
+ gamma=0.0,
31
  seed=100,
32
  pretrained_model_path=None,
33
  save=False,
34
  min_reward=9.9,
35
+ every_n_save=100,
36
  ):
37
  os.environ["OMP_NUM_THREADS"] = "1"
38
  if not os.path.exists(model_checkpoint_dir):
 
47
  if pretrained_model_path:
48
  gnet.load_state_dict(torch.load(pretrained_model_path))
49
  gnet.share_memory() # share the global parameters in multiprocessing
50
+ opt = SharedAdam(
51
+ gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)
52
+ ) # global optimizer
53
+ global_ep, global_ep_r, res_queue, win_ep = (
54
+ mp.Value("i", 0),
55
+ mp.Value("d", 0.0),
56
+ mp.Queue(),
57
+ mp.Value("i", 0),
58
+ )
59
 
60
  # parallel training
61
  workers = [
62
  Worker(
63
+ max_ep,
64
+ gnet,
65
+ opt,
66
+ global_ep,
67
+ global_ep_r,
68
+ res_queue,
69
+ i,
70
+ env,
71
+ n_s,
72
+ n_a,
73
+ words_list,
74
+ word_width,
75
+ win_ep,
76
+ model_checkpoint_dir,
77
+ gamma,
78
+ pretrained_model_path,
79
+ save,
80
+ min_reward,
81
+ every_n_save,
82
+ )
83
+ for i in range(mp.cpu_count())
84
  ]
85
  [w.start() for w in workers]
86
  res = [] # record episode reward to plot
 
92
  break
93
  [w.join() for w in workers]
94
  if save:
95
+ torch.save(
96
+ gnet.state_dict(),
97
+ os.path.join(model_checkpoint_dir, f"model_{env.unwrapped.spec.id}.pth"),
98
+ )
99
  return global_ep, win_ep, gnet, res
a3c/utils.py CHANGED
@@ -1,5 +1,5 @@
1
- import torch
2
  import numpy as np
 
3
 
4
 
5
  def v_wrap(np_array, dtype=np.float32):
 
 
1
  import numpy as np
2
+ import torch
3
 
4
 
5
  def v_wrap(np_array, dtype=np.float32):
a3c/worker.py CHANGED
@@ -2,40 +2,42 @@
2
  Worker class implementation of the a3c discrete algorithm
3
  """
4
  import os
5
- import torch
6
  import numpy as np
 
7
  import torch.multiprocessing as mp
8
  from torch import nn
 
9
  from .net import Net
10
  from .utils import v_wrap
11
 
12
 
13
  class Worker(mp.Process):
14
  def __init__(
15
- self,
16
- max_ep,
17
- gnet,
18
- opt,
19
- global_ep,
20
- global_ep_r,
21
- res_queue,
22
- name,
23
- env,
24
- N_S,
25
- N_A,
26
- words_list,
27
- word_width,
28
- winning_ep,
29
- model_checkpoint_dir,
30
- gamma=0.,
31
- pretrained_model_path=None,
32
- save=False,
33
- min_reward=9.9,
34
- every_n_save=100
35
  ):
36
  super(Worker, self).__init__()
37
  self.max_ep = max_ep
38
- self.name = 'w%02i' % name
39
  self.g_ep = global_ep
40
  self.g_ep_r = global_ep_r
41
  self.res_queue = res_queue
@@ -57,7 +59,7 @@ class Worker(mp.Process):
57
  while self.g_ep.value < self.max_ep:
58
  s = self.env.reset()
59
  buffer_s, buffer_a, buffer_r = [], [], []
60
- ep_r = 0.
61
  while True:
62
  a = self.lnet.choose_action(v_wrap(s[None, :]))
63
  s_, r, done, _ = self.env.step(a)
@@ -68,11 +70,9 @@ class Worker(mp.Process):
68
 
69
  if done: # update global and assign to local net
70
  # sync
71
- self.push_and_pull(done, s_, buffer_s,
72
- buffer_a, buffer_r)
73
  goal_word = self.word_list[self.env.goal_word]
74
- self.record(ep_r, goal_word,
75
- self.word_list[a], len(buffer_a))
76
  self.save_model()
77
  buffer_s, buffer_a, buffer_r = [], [], []
78
  break
@@ -81,22 +81,22 @@ class Worker(mp.Process):
81
 
82
  def push_and_pull(self, done, s_, bs, ba, br):
83
  if done:
84
- v_s_ = 0. # terminal
85
  else:
86
- v_s_ = self.lnet.forward(v_wrap(
87
- s_[None, :]))[-1].data.numpy()[0, 0]
88
 
89
  buffer_v_target = []
90
- for r in br[::-1]: # reverse buffer r
91
  v_s_ = r + self.gamma * v_s_
92
  buffer_v_target.append(v_s_)
93
  buffer_v_target.reverse()
94
 
95
  loss = self.lnet.loss_func(
96
  v_wrap(np.vstack(bs)),
97
- v_wrap(np.array(ba), dtype=np.int64) if
98
- ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
99
- v_wrap(np.array(buffer_v_target)[:, None])
 
100
  )
101
 
102
  # calculate local gradients and push local parameters to global
@@ -110,16 +110,21 @@ class Worker(mp.Process):
110
  self.lnet.load_state_dict(self.gnet.state_dict())
111
 
112
  def save_model(self):
113
- if (self.save and self.g_ep_r.value >= self.min_reward and
114
- self.g_ep.value % self.every_n_save == 0):
115
- torch.save(self.gnet.state_dict(), os.path.join(
116
- self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))
 
 
 
 
 
117
 
118
  def record(self, ep_r, goal_word, action, action_number):
119
  with self.g_ep.get_lock():
120
  self.g_ep.value += 1
121
  with self.g_ep_r.get_lock():
122
- if self.g_ep_r.value == 0.:
123
  self.g_ep_r.value = ep_r
124
  else:
125
  self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
@@ -129,9 +134,13 @@ class Worker(mp.Process):
129
  if self.g_ep.value % 100 == 0:
130
  print(
131
  self.name,
132
- "Ep:", self.g_ep.value,
 
133
  "| Ep_r: %.0f" % self.g_ep_r.value,
134
- "| Goal :", goal_word,
135
- "| Action: ", action,
136
- "| Actions: ", action_number
 
 
 
137
  )
 
2
  Worker class implementation of the a3c discrete algorithm
3
  """
4
  import os
5
+
6
  import numpy as np
7
+ import torch
8
  import torch.multiprocessing as mp
9
  from torch import nn
10
+
11
  from .net import Net
12
  from .utils import v_wrap
13
 
14
 
15
  class Worker(mp.Process):
16
  def __init__(
17
+ self,
18
+ max_ep,
19
+ gnet,
20
+ opt,
21
+ global_ep,
22
+ global_ep_r,
23
+ res_queue,
24
+ name,
25
+ env,
26
+ N_S,
27
+ N_A,
28
+ words_list,
29
+ word_width,
30
+ winning_ep,
31
+ model_checkpoint_dir,
32
+ gamma=0.0,
33
+ pretrained_model_path=None,
34
+ save=False,
35
+ min_reward=9.9,
36
+ every_n_save=100,
37
  ):
38
  super(Worker, self).__init__()
39
  self.max_ep = max_ep
40
+ self.name = "w%02i" % name
41
  self.g_ep = global_ep
42
  self.g_ep_r = global_ep_r
43
  self.res_queue = res_queue
 
59
  while self.g_ep.value < self.max_ep:
60
  s = self.env.reset()
61
  buffer_s, buffer_a, buffer_r = [], [], []
62
+ ep_r = 0.0
63
  while True:
64
  a = self.lnet.choose_action(v_wrap(s[None, :]))
65
  s_, r, done, _ = self.env.step(a)
 
70
 
71
  if done: # update global and assign to local net
72
  # sync
73
+ self.push_and_pull(done, s_, buffer_s, buffer_a, buffer_r)
 
74
  goal_word = self.word_list[self.env.goal_word]
75
+ self.record(ep_r, goal_word, self.word_list[a], len(buffer_a))
 
76
  self.save_model()
77
  buffer_s, buffer_a, buffer_r = [], [], []
78
  break
 
81
 
82
  def push_and_pull(self, done, s_, bs, ba, br):
83
  if done:
84
+ v_s_ = 0.0 # terminal
85
  else:
86
+ v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
 
87
 
88
  buffer_v_target = []
89
+ for r in br[::-1]: # reverse buffer r
90
  v_s_ = r + self.gamma * v_s_
91
  buffer_v_target.append(v_s_)
92
  buffer_v_target.reverse()
93
 
94
  loss = self.lnet.loss_func(
95
  v_wrap(np.vstack(bs)),
96
+ v_wrap(np.array(ba), dtype=np.int64)
97
+ if ba[0].dtype == np.int64
98
+ else v_wrap(np.vstack(ba)),
99
+ v_wrap(np.array(buffer_v_target)[:, None]),
100
  )
101
 
102
  # calculate local gradients and push local parameters to global
 
110
  self.lnet.load_state_dict(self.gnet.state_dict())
111
 
112
  def save_model(self):
113
+ if (
114
+ self.save
115
+ and self.g_ep_r.value >= self.min_reward
116
+ and self.g_ep.value % self.every_n_save == 0
117
+ ):
118
+ torch.save(
119
+ self.gnet.state_dict(),
120
+ os.path.join(self.model_checkpoint_dir, f"model_{self.g_ep.value}.pth"),
121
+ )
122
 
123
  def record(self, ep_r, goal_word, action, action_number):
124
  with self.g_ep.get_lock():
125
  self.g_ep.value += 1
126
  with self.g_ep_r.get_lock():
127
+ if self.g_ep_r.value == 0.0:
128
  self.g_ep_r.value = ep_r
129
  else:
130
  self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
 
134
  if self.g_ep.value % 100 == 0:
135
  print(
136
  self.name,
137
+ "Ep:",
138
+ self.g_ep.value,
139
  "| Ep_r: %.0f" % self.g_ep_r.value,
140
+ "| Goal :",
141
+ goal_word,
142
+ "| Action: ",
143
+ action,
144
+ "| Actions: ",
145
+ action_number,
146
  )
api_rest/api.py CHANGED
@@ -1,30 +1,32 @@
1
  import random
2
- from a3c.play import get_play_model_path, play
3
- from flask import Flask, request, jsonify
4
  from flask_cors import cross_origin
5
- from wordle_env.words import target_vocabulary
 
6
  from wordle_env.wordle import get_env
 
7
 
8
  app = Flask(__name__)
9
 
10
 
11
  def validate_goal_word(word):
12
  if not word:
13
- return True, 'Goal word not provided'
14
  if word.upper() not in target_vocabulary:
15
- return True, 'Goal word not in vocabulary'
16
- return False, ''
17
 
18
 
19
- @app.route('/play_word', methods=['GET'])
20
- @cross_origin(origin='*', headers=['Content-Type', 'Authorization'])
21
  def get_play():
22
  # Get the goal word from the request
23
- word = request.args.get('goal_word')
24
 
25
  error, msge = validate_goal_word(word)
26
  if error:
27
- return jsonify({'error': msge}), 400
28
 
29
  word = word.upper()
30
  env = get_env()
@@ -32,16 +34,16 @@ def get_play():
32
  # Call the play function with the goal word
33
  # and return the attempts and the result
34
  won, attempts = play(env, model_path, word)
35
- return jsonify({'attempts': attempts, 'won': won})
36
 
37
 
38
- @app.route('/word', methods=['GET'])
39
- @cross_origin(origin='*', headers=['Content-Type', 'Authorization'])
40
  def get_word():
41
  # Get a random word from the target vocabulary used to train the model
42
  word = random.choice(target_vocabulary)
43
  word = word.upper()
44
- return jsonify({'word': word})
45
 
46
 
47
  def create_app(settings_override=None):
@@ -58,5 +60,5 @@ def create_app(settings_override=None):
58
  return app
59
 
60
 
61
- if __name__ == '__main__':
62
  app.run(debug=True)
 
1
  import random
2
+
3
+ from flask import Flask, jsonify, request
4
  from flask_cors import cross_origin
5
+
6
+ from a3c.play import get_play_model_path, play
7
  from wordle_env.wordle import get_env
8
+ from wordle_env.words import target_vocabulary
9
 
10
  app = Flask(__name__)
11
 
12
 
13
  def validate_goal_word(word):
14
  if not word:
15
+ return True, "Goal word not provided"
16
  if word.upper() not in target_vocabulary:
17
+ return True, "Goal word not in vocabulary"
18
+ return False, ""
19
 
20
 
21
+ @app.route("/play_word", methods=["GET"])
22
+ @cross_origin(origin="*", headers=["Content-Type", "Authorization"])
23
  def get_play():
24
  # Get the goal word from the request
25
+ word = request.args.get("goal_word")
26
 
27
  error, msge = validate_goal_word(word)
28
  if error:
29
+ return jsonify({"error": msge}), 400
30
 
31
  word = word.upper()
32
  env = get_env()
 
34
  # Call the play function with the goal word
35
  # and return the attempts and the result
36
  won, attempts = play(env, model_path, word)
37
+ return jsonify({"attempts": attempts, "won": won})
38
 
39
 
40
+ @app.route("/word", methods=["GET"])
41
+ @cross_origin(origin="*", headers=["Content-Type", "Authorization"])
42
  def get_word():
43
  # Get a random word from the target vocabulary used to train the model
44
  word = random.choice(target_vocabulary)
45
  word = word.upper()
46
+ return jsonify({"word": word})
47
 
48
 
49
  def create_app(settings_override=None):
 
60
  return app
61
 
62
 
63
+ if __name__ == "__main__":
64
  app.run(debug=True)
main.py CHANGED
@@ -3,23 +3,33 @@
3
  import argparse
4
  import os
5
  import time
 
6
  import matplotlib.pyplot as plt
7
- from a3c.train import train
8
  from a3c.eval import evaluate, evaluate_checkpoints
9
  from a3c.play import suggest
 
10
  from wordle_env.wordle import get_env
11
 
12
 
13
  def training_mode(args, env, model_checkpoint_dir):
14
  max_ep = args.games
15
  start_time = time.time()
16
- pretrained_model_path = os.path.join(
17
- model_checkpoint_dir, args.model_name
18
- ) if args.model_name else args.model_name
 
 
19
  global_ep, win_ep, gnet, res = train(
20
- env, max_ep, model_checkpoint_dir, args.gamma,
21
- args.seed, pretrained_model_path, args.save,
22
- args.min_reward, args.every_n_save
 
 
 
 
 
 
23
  )
24
  print("--- %.0f seconds ---" % (time.time() - start_time))
25
  print_results(global_ep, win_ep, res)
@@ -34,8 +44,8 @@ def evaluation_mode(args, env, model_checkpoint_dir):
34
 
35
  def play_mode(args, env, model_checkpoint_dir):
36
  print("Play mode")
37
- words = [word.strip() for word in args.words.split(',')]
38
- states = [state.strip() for state in args.states.split(',')]
39
  pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
40
  word = suggest(env, words, states, pretrained_model_path)
41
  print(word)
@@ -45,8 +55,8 @@ def print_results(global_ep, win_ep, res):
45
  print("Jugadas:", global_ep.value)
46
  print("Ganadas:", win_ep.value)
47
  plt.plot(res)
48
- plt.ylabel('Moving average ep reward')
49
- plt.xlabel('Step')
50
  plt.show()
51
 
52
 
@@ -55,90 +65,80 @@ if __name__ == "__main__":
55
  parser.add_argument(
56
  "enviroment",
57
  help="Enviroment (type of wordle game) used for training, \
58
- example: WordleEnvFull-v0"
59
  )
60
  parser.add_argument(
61
  "--models_dir",
62
  help="Directory where models are saved (default=checkpoints)",
63
- default='checkpoints'
64
  )
65
- subparsers = parser.add_subparsers(help='sub-command help')
66
 
67
  parser_train = subparsers.add_parser(
68
- 'train',
69
- help='Train a model from scratch or train from pretrained model'
70
  )
71
  parser_train.add_argument(
72
- "--games",
73
- "-g",
74
- help="Number of games to train",
75
- type=int,
76
- required=True
77
  )
78
  parser_train.add_argument(
79
  "--model_name",
80
  "-m",
81
  help="If want to train from a pretrained model, \
82
- the name of the pretrained model file"
83
  )
84
  parser_train.add_argument(
85
  "--gamma",
86
  help="Gamma hyperparameter (discount factor) value",
87
  type=float,
88
- default=0.
89
  )
90
  parser_train.add_argument(
91
- "--seed",
92
- help="Seed used for random numbers generation",
93
- type=int,
94
- default=100
95
  )
96
  parser_train.add_argument(
97
  "--save",
98
- '-s',
99
  help="Save instances of the model while training",
100
- action='store_true'
101
  )
102
  parser_train.add_argument(
103
  "--min_reward",
104
  help="The minimun global reward value achieved for saving the model",
105
  type=float,
106
- default=9.9
107
  )
108
  parser_train.add_argument(
109
  "--every_n_save",
110
  help="Check every n training steps to save the model",
111
  type=int,
112
- default=100
113
  )
114
  parser_train.set_defaults(func=training_mode)
115
 
116
  parser_eval = subparsers.add_parser(
117
- 'eval', help='Evaluate saved models for the enviroment')
 
118
  parser_eval.set_defaults(func=evaluation_mode)
119
 
120
  parser_play = subparsers.add_parser(
121
- 'play',
122
- help='Give the model a word and the state result \
123
- and the model will try to predict the goal word'
124
  )
125
  parser_play.add_argument(
126
- "--words",
127
- "-w",
128
- help="List of words played in the wordle game",
129
- required=True
130
  )
131
  parser_play.add_argument(
132
  "--states",
133
  "-st",
134
  help="List of states returned by playing each of the words",
135
- required=True
136
  )
137
  parser_play.add_argument(
138
  "--model_name",
139
  "-m",
140
  help="Name of the pretrained model file thich will play the game",
141
- required=True
142
  )
143
  parser_play.set_defaults(func=play_mode)
144
 
 
3
  import argparse
4
  import os
5
  import time
6
+
7
  import matplotlib.pyplot as plt
8
+
9
  from a3c.eval import evaluate, evaluate_checkpoints
10
  from a3c.play import suggest
11
+ from a3c.train import train
12
  from wordle_env.wordle import get_env
13
 
14
 
15
  def training_mode(args, env, model_checkpoint_dir):
16
  max_ep = args.games
17
  start_time = time.time()
18
+ pretrained_model_path = (
19
+ os.path.join(model_checkpoint_dir, args.model_name)
20
+ if args.model_name
21
+ else args.model_name
22
+ )
23
  global_ep, win_ep, gnet, res = train(
24
+ env,
25
+ max_ep,
26
+ model_checkpoint_dir,
27
+ args.gamma,
28
+ args.seed,
29
+ pretrained_model_path,
30
+ args.save,
31
+ args.min_reward,
32
+ args.every_n_save,
33
  )
34
  print("--- %.0f seconds ---" % (time.time() - start_time))
35
  print_results(global_ep, win_ep, res)
 
44
 
45
  def play_mode(args, env, model_checkpoint_dir):
46
  print("Play mode")
47
+ words = [word.strip() for word in args.words.split(",")]
48
+ states = [state.strip() for state in args.states.split(",")]
49
  pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
50
  word = suggest(env, words, states, pretrained_model_path)
51
  print(word)
 
55
  print("Jugadas:", global_ep.value)
56
  print("Ganadas:", win_ep.value)
57
  plt.plot(res)
58
+ plt.ylabel("Moving average ep reward")
59
+ plt.xlabel("Step")
60
  plt.show()
61
 
62
 
 
65
  parser.add_argument(
66
  "enviroment",
67
  help="Enviroment (type of wordle game) used for training, \
68
+ example: WordleEnvFull-v0",
69
  )
70
  parser.add_argument(
71
  "--models_dir",
72
  help="Directory where models are saved (default=checkpoints)",
73
+ default="checkpoints",
74
  )
75
+ subparsers = parser.add_subparsers(help="sub-command help")
76
 
77
  parser_train = subparsers.add_parser(
78
+ "train", help="Train a model from scratch or train from pretrained model"
 
79
  )
80
  parser_train.add_argument(
81
+ "--games", "-g", help="Number of games to train", type=int, required=True
 
 
 
 
82
  )
83
  parser_train.add_argument(
84
  "--model_name",
85
  "-m",
86
  help="If want to train from a pretrained model, \
87
+ the name of the pretrained model file",
88
  )
89
  parser_train.add_argument(
90
  "--gamma",
91
  help="Gamma hyperparameter (discount factor) value",
92
  type=float,
93
+ default=0.0,
94
  )
95
  parser_train.add_argument(
96
+ "--seed", help="Seed used for random numbers generation", type=int, default=100
 
 
 
97
  )
98
  parser_train.add_argument(
99
  "--save",
100
+ "-s",
101
  help="Save instances of the model while training",
102
+ action="store_true",
103
  )
104
  parser_train.add_argument(
105
  "--min_reward",
106
  help="The minimun global reward value achieved for saving the model",
107
  type=float,
108
+ default=9.9,
109
  )
110
  parser_train.add_argument(
111
  "--every_n_save",
112
  help="Check every n training steps to save the model",
113
  type=int,
114
+ default=100,
115
  )
116
  parser_train.set_defaults(func=training_mode)
117
 
118
  parser_eval = subparsers.add_parser(
119
+ "eval", help="Evaluate saved models for the enviroment"
120
+ )
121
  parser_eval.set_defaults(func=evaluation_mode)
122
 
123
  parser_play = subparsers.add_parser(
124
+ "play",
125
+ help="Give the model a word and the state result \
126
+ and the model will try to predict the goal word",
127
  )
128
  parser_play.add_argument(
129
+ "--words", "-w", help="List of words played in the wordle game", required=True
 
 
 
130
  )
131
  parser_play.add_argument(
132
  "--states",
133
  "-st",
134
  help="List of states returned by playing each of the words",
135
+ required=True,
136
  )
137
  parser_play.add_argument(
138
  "--model_name",
139
  "-m",
140
  help="Name of the pretrained model file thich will play the game",
141
+ required=True,
142
  )
143
  parser_play.set_defaults(func=play_mode)
144
 
rs_wordle_player/firebase_connector.py CHANGED
@@ -1,13 +1,12 @@
1
  import os
2
- import firebase_admin
3
- from firebase_admin import credentials
4
- from firebase_admin import firestore
5
  from datetime import datetime
6
- from dotenv import load_dotenv
7
 
 
 
 
8
 
9
- class FirebaseConnector():
10
 
 
11
  def __init__(self):
12
  load_dotenv()
13
  cert_path = self.get_credentials_path()
@@ -20,32 +19,34 @@ class FirebaseConnector():
20
  return db
21
 
22
  def get_credentials_path(self):
23
- credentials_path = os.getenv('RS_FIREBASE_CREDENTIALS_PATH')
24
  return credentials_path
25
 
26
  def get_user(self):
27
- user = os.getenv('RS_WORDLE_USER')
28
  return user
29
 
30
  def get_state_from_fb_result(self, firebase_result):
31
- result_number_map = {'incorrect': '0',
32
- 'misplaced': '1',
33
- 'correct': '2'}
34
  char_result_map = map(
35
  lambda char_res: result_number_map[char_res], firebase_result
36
  )
37
- return ''.join(char_result_map)
38
 
39
  def today(self):
40
- return datetime.today().strftime('%Y%m%d')
41
 
42
  def today_user_results(self):
43
- daily_results_col = 'dailyResults'
44
  currentUser = self.get_user()
45
  # Execute the query and get the first result
46
- docs = self.db.collection(daily_results_col).where(
47
- 'user.email', '==', currentUser).where(
48
- 'date', '==', self.today()).limit(1).get()
 
 
 
 
49
  return docs
50
 
51
  def today_user_attempts(self):
@@ -53,10 +54,10 @@ class FirebaseConnector():
53
  attempted_words = []
54
  if len(docs) > 0:
55
  doc = docs[0]
56
- attempted_words = doc.to_dict().get('attemptedWords')
57
  return attempted_words
58
 
59
  def today_word(self):
60
- words_col = 'words'
61
  doc = self.db.collection(words_col).document(self.today())
62
- return doc.get().get('word')
 
1
  import os
 
 
 
2
  from datetime import datetime
 
3
 
4
+ import firebase_admin
5
+ from dotenv import load_dotenv
6
+ from firebase_admin import credentials, firestore
7
 
 
8
 
9
+ class FirebaseConnector:
10
  def __init__(self):
11
  load_dotenv()
12
  cert_path = self.get_credentials_path()
 
19
  return db
20
 
21
  def get_credentials_path(self):
22
+ credentials_path = os.getenv("RS_FIREBASE_CREDENTIALS_PATH")
23
  return credentials_path
24
 
25
  def get_user(self):
26
+ user = os.getenv("RS_WORDLE_USER")
27
  return user
28
 
29
  def get_state_from_fb_result(self, firebase_result):
30
+ result_number_map = {"incorrect": "0", "misplaced": "1", "correct": "2"}
 
 
31
  char_result_map = map(
32
  lambda char_res: result_number_map[char_res], firebase_result
33
  )
34
+ return "".join(char_result_map)
35
 
36
  def today(self):
37
+ return datetime.today().strftime("%Y%m%d")
38
 
39
  def today_user_results(self):
40
+ daily_results_col = "dailyResults"
41
  currentUser = self.get_user()
42
  # Execute the query and get the first result
43
+ docs = (
44
+ self.db.collection(daily_results_col)
45
+ .where("user.email", "==", currentUser)
46
+ .where("date", "==", self.today())
47
+ .limit(1)
48
+ .get()
49
+ )
50
  return docs
51
 
52
  def today_user_attempts(self):
 
54
  attempted_words = []
55
  if len(docs) > 0:
56
  doc = docs[0]
57
+ attempted_words = doc.to_dict().get("attemptedWords")
58
  return attempted_words
59
 
60
  def today_word(self):
61
+ words_col = "words"
62
  doc = self.db.collection(words_col).document(self.today())
63
+ return doc.get().get("word")
rs_wordle_player/rs_wordle_player.py CHANGED
@@ -1,5 +1,6 @@
1
  from a3c.play import get_play_model_path, suggest
2
  from wordle_env.wordle import get_env
 
3
  from .firebase_connector import FirebaseConnector
4
  from .selenium_player import SeleniumPlayer
5
 
@@ -17,7 +18,7 @@ def get_attempts(fb_connector):
17
 
18
  def is_game_finished(states):
19
  if states:
20
- return states[-1] == '22222' or len(states) == 6
21
  return False
22
 
23
 
@@ -49,5 +50,5 @@ def play():
49
  return words, won
50
 
51
 
52
- if __name__ == '__main__':
53
  print(play())
 
1
  from a3c.play import get_play_model_path, suggest
2
  from wordle_env.wordle import get_env
3
+
4
  from .firebase_connector import FirebaseConnector
5
  from .selenium_player import SeleniumPlayer
6
 
 
18
 
19
  def is_game_finished(states):
20
  if states:
21
+ return states[-1] == "22222" or len(states) == 6
22
  return False
23
 
24
 
 
50
  return words, won
51
 
52
 
53
+ if __name__ == "__main__":
54
  print(play())
rs_wordle_player/selenium_player.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import time
 
3
  from dotenv import load_dotenv
4
  from selenium import webdriver
5
  from selenium.common.exceptions import UnexpectedAlertPresentException
@@ -8,8 +9,7 @@ from selenium.webdriver.common.by import By
8
  from selenium.webdriver.common.keys import Keys
9
 
10
 
11
- class SeleniumPlayer():
12
-
13
  def __init__(self):
14
  self.wordle_url = self.get_wordle_url()
15
  self.driver = self.get_driver()
@@ -24,22 +24,22 @@ class SeleniumPlayer():
24
 
25
  def get_wordle_url(self):
26
  load_dotenv()
27
- return os.getenv('RS_WORDLE_URL')
28
 
29
  def get_credentials(self):
30
  load_dotenv()
31
- username = os.getenv('RS_WORDLE_USER')
32
- password = os.getenv('RS_WORDLE_PASSWORD')
33
  return username, password
34
 
35
  def logged_in(self):
36
- return self.driver.current_url != self.wordle_url + '/login'
37
 
38
  def log_in(self):
39
  if not self.logged_in():
40
  time.sleep(2)
41
- login_div = self.driver.find_element(By.CLASS_NAME, 'login-button')
42
- login_btns = login_div.find_elements(By.TAG_NAME, 'button')
43
  login_btn = login_btns[0]
44
  login_btn.click()
45
  time.sleep(10)
@@ -47,32 +47,31 @@ class SeleniumPlayer():
47
  login_window = self.driver.window_handles[1]
48
  self.driver.switch_to.window(login_window)
49
  username, password = self.get_credentials()
50
- element = self.driver.find_element(By.ID, 'identifierId')
51
  element.send_keys(username)
52
  element.send_keys(Keys.ENTER)
53
  time.sleep(10)
54
- element = self.driver.find_element(By.NAME, 'password')
55
  element.send_keys(password)
56
  element.send_keys(Keys.ENTER)
57
  self.driver.switch_to.window(wordle_window)
58
  time.sleep(5)
59
  onboard_div = self.driver.find_element(
60
- By.CLASS_NAME,
61
- 'onboarding-modal-container'
62
  )
63
- onboard_btn = onboard_div.find_elements(By.TAG_NAME, 'button')
64
  onboard_btn[-1].click()
65
 
66
  def play_word(self, word):
67
  try:
68
- element = self.driver.find_element(By.TAG_NAME, 'html')
69
  # simulate typing the letters in the word into the input field
70
  element.send_keys(word)
71
  # simulate pressing the Enter key
72
  element.send_keys(Keys.ENTER)
73
  time.sleep(5)
74
  except UnexpectedAlertPresentException:
75
- print('Won game alert on screen')
76
 
77
  def finish(self):
78
  self.driver.quit()
 
1
  import os
2
  import time
3
+
4
  from dotenv import load_dotenv
5
  from selenium import webdriver
6
  from selenium.common.exceptions import UnexpectedAlertPresentException
 
9
  from selenium.webdriver.common.keys import Keys
10
 
11
 
12
+ class SeleniumPlayer:
 
13
  def __init__(self):
14
  self.wordle_url = self.get_wordle_url()
15
  self.driver = self.get_driver()
 
24
 
25
  def get_wordle_url(self):
26
  load_dotenv()
27
+ return os.getenv("RS_WORDLE_URL")
28
 
29
  def get_credentials(self):
30
  load_dotenv()
31
+ username = os.getenv("RS_WORDLE_USER")
32
+ password = os.getenv("RS_WORDLE_PASSWORD")
33
  return username, password
34
 
35
  def logged_in(self):
36
+ return self.driver.current_url != self.wordle_url + "/login"
37
 
38
  def log_in(self):
39
  if not self.logged_in():
40
  time.sleep(2)
41
+ login_div = self.driver.find_element(By.CLASS_NAME, "login-button")
42
+ login_btns = login_div.find_elements(By.TAG_NAME, "button")
43
  login_btn = login_btns[0]
44
  login_btn.click()
45
  time.sleep(10)
 
47
  login_window = self.driver.window_handles[1]
48
  self.driver.switch_to.window(login_window)
49
  username, password = self.get_credentials()
50
+ element = self.driver.find_element(By.ID, "identifierId")
51
  element.send_keys(username)
52
  element.send_keys(Keys.ENTER)
53
  time.sleep(10)
54
+ element = self.driver.find_element(By.NAME, "password")
55
  element.send_keys(password)
56
  element.send_keys(Keys.ENTER)
57
  self.driver.switch_to.window(wordle_window)
58
  time.sleep(5)
59
  onboard_div = self.driver.find_element(
60
+ By.CLASS_NAME, "onboarding-modal-container"
 
61
  )
62
+ onboard_btn = onboard_div.find_elements(By.TAG_NAME, "button")
63
  onboard_btn[-1].click()
64
 
65
  def play_word(self, word):
66
  try:
67
+ element = self.driver.find_element(By.TAG_NAME, "html")
68
  # simulate typing the letters in the word into the input field
69
  element.send_keys(word)
70
  # simulate pressing the Enter key
71
  element.send_keys(Keys.ENTER)
72
  time.sleep(5)
73
  except UnexpectedAlertPresentException:
74
+ print("Won game alert on screen")
75
 
76
  def finish(self):
77
  self.driver.quit()
wordle_env/__init__.py CHANGED
@@ -1,13 +1,9 @@
1
- from gym.envs.registration import (
2
- registry,
3
- register,
4
- make,
5
- spec,
6
- load_env_plugins as _load_env_plugins,
7
- )
8
  import os
9
- from . import wordle
10
 
 
 
 
 
11
 
12
  register(
13
  id="WordleEnv100OneAction-v0",
 
 
 
 
 
 
 
 
1
  import os
 
2
 
3
+ from gym.envs.registration import load_env_plugins as _load_env_plugins
4
+ from gym.envs.registration import make, register, registry, spec
5
+
6
+ from . import wordle
7
 
8
  register(
9
  id="WordleEnv100OneAction-v0",
wordle_env/const.py CHANGED
@@ -1,4 +1,4 @@
1
- WORDLE_CHARS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
2
  WORDLE_N = 5
3
  REWARD = 10
4
  CHAR_REWARD = 0.1
 
1
+ WORDLE_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
2
  WORDLE_N = 5
3
  REWARD = 10
4
  CHAR_REWARD = 0.1
wordle_env/state.py CHANGED
@@ -13,11 +13,11 @@ where status has codes
13
  """
14
  import collections
15
  from typing import List, Tuple
 
16
  import numpy as np
17
 
18
  from .const import CHAR_REWARD, WORDLE_CHARS, WORDLE_N
19
 
20
-
21
  WordleState = np.ndarray
22
 
23
 
@@ -27,8 +27,8 @@ def get_nvec(max_turns: int):
27
 
28
  def new(max_turns: int) -> WordleState:
29
  return np.array(
30
- [max_turns] + [0, 0, 0] * WORDLE_N * len(WORDLE_CHARS),
31
- dtype=np.int32)
32
 
33
 
34
  def remaining_steps(state: WordleState) -> int:
@@ -40,11 +40,7 @@ SOMEWHERE = 1
40
  YES = 2
41
 
42
 
43
- def update_from_mask(
44
- state: WordleState,
45
- word: str,
46
- mask: List[int]
47
- ) -> WordleState:
48
  """
49
  return a copy of state that has been updated to new state
50
 
@@ -84,14 +80,14 @@ def update_from_mask(
84
  # Need to check this first in case there's prior maybe + yes
85
  if c in prior_maybe:
86
  # Then the maybe could be anywhere except here
87
- state[offset+3*i:offset+3*i+3] = [1, 0, 0]
88
  elif c in prior_yes:
89
  # No maybe, definitely a yes,
90
  # so it's zero everywhere except the yesses
91
  for j in range(WORDLE_N):
92
  # Only flip no if previously was maybe
93
- if state[offset + 3 * j:offset + 3 * j + 3][1] == 1:
94
- state[offset + 3 * j:offset + 3 * j + 3] = [1, 0, 0]
95
  else:
96
  # Just straight up no
97
  _set_all_no(state, offset)
@@ -115,7 +111,7 @@ def get_mask(word: str, goal_word: str) -> List[int]:
115
  mask[i] = 1
116
  counts[c] -= 1
117
  else:
118
- for j in range(i+1, len(mask)):
119
  if mask[j] == 2:
120
  continue
121
  mask[j] = 0
@@ -136,11 +132,7 @@ def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
136
  return update_from_mask(state, word, mask)
137
 
138
 
139
- def update(
140
- state: WordleState,
141
- word: str,
142
- goal_word: str
143
- ) -> Tuple[WordleState, float]:
144
  state = state.copy()
145
  reward = 0
146
  state[0] -= 1
@@ -158,8 +150,7 @@ def update(
158
  cint = ord(c) - ord(WORDLE_CHARS[0])
159
  offset = 1 + cint * WORDLE_N * 3
160
  if goal_word[i] != c:
161
- if (c in goal_word and
162
- goal_word.count(c) > processed_letters.count(c)):
163
  # Char at position i = no,
164
  # and in other positions maybe except it had a value before,
165
  # other chars stay as they are
@@ -184,27 +175,27 @@ def _set_if_cero(state, offset, value):
184
  # but only if it didnt have a value before
185
  for char_idx in range(0, WORDLE_N * 3, 3):
186
  char_offset = offset + char_idx
187
- if tuple(state[char_offset: char_offset + 3]) == (0, 0, 0):
188
- state[char_offset: char_offset + 3] = value
189
 
190
 
191
  def _set_yes(state, offset, char_int, char_pos):
192
  # char at position char_pos = yes,
193
  # all other chars at position char_pos == no
194
  pos_offset = 3 * char_pos
195
- state[offset + pos_offset:offset + pos_offset + 3] = [0, 0, 1]
196
  for ocint in range(len(WORDLE_CHARS)):
197
  if ocint != char_int:
198
  oc_offset = 1 + ocint * WORDLE_N * 3
199
  yes_index = oc_offset + pos_offset
200
- state[yes_index:yes_index + 3] = [1, 0, 0]
201
 
202
 
203
  def _set_no(state, offset, char_pos):
204
  # Set offset character = no at char_pos position
205
- state[offset + 3 * char_pos:offset + 3 * char_pos + 3] = [1, 0, 0]
206
 
207
 
208
  def _set_all_no(state, offset):
209
  # Set offset character = no at all positions
210
- state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
 
13
  """
14
  import collections
15
  from typing import List, Tuple
16
+
17
  import numpy as np
18
 
19
  from .const import CHAR_REWARD, WORDLE_CHARS, WORDLE_N
20
 
 
21
  WordleState = np.ndarray
22
 
23
 
 
27
 
28
  def new(max_turns: int) -> WordleState:
29
  return np.array(
30
+ [max_turns] + [0, 0, 0] * WORDLE_N * len(WORDLE_CHARS), dtype=np.int32
31
+ )
32
 
33
 
34
  def remaining_steps(state: WordleState) -> int:
 
40
  YES = 2
41
 
42
 
43
+ def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleState:
 
 
 
 
44
  """
45
  return a copy of state that has been updated to new state
46
 
 
80
  # Need to check this first in case there's prior maybe + yes
81
  if c in prior_maybe:
82
  # Then the maybe could be anywhere except here
83
+ state[offset + 3 * i : offset + 3 * i + 3] = [1, 0, 0]
84
  elif c in prior_yes:
85
  # No maybe, definitely a yes,
86
  # so it's zero everywhere except the yesses
87
  for j in range(WORDLE_N):
88
  # Only flip no if previously was maybe
89
+ if state[offset + 3 * j : offset + 3 * j + 3][1] == 1:
90
+ state[offset + 3 * j : offset + 3 * j + 3] = [1, 0, 0]
91
  else:
92
  # Just straight up no
93
  _set_all_no(state, offset)
 
111
  mask[i] = 1
112
  counts[c] -= 1
113
  else:
114
+ for j in range(i + 1, len(mask)):
115
  if mask[j] == 2:
116
  continue
117
  mask[j] = 0
 
132
  return update_from_mask(state, word, mask)
133
 
134
 
135
+ def update(state: WordleState, word: str, goal_word: str) -> Tuple[WordleState, float]:
 
 
 
 
136
  state = state.copy()
137
  reward = 0
138
  state[0] -= 1
 
150
  cint = ord(c) - ord(WORDLE_CHARS[0])
151
  offset = 1 + cint * WORDLE_N * 3
152
  if goal_word[i] != c:
153
+ if c in goal_word and goal_word.count(c) > processed_letters.count(c):
 
154
  # Char at position i = no,
155
  # and in other positions maybe except it had a value before,
156
  # other chars stay as they are
 
175
  # but only if it didnt have a value before
176
  for char_idx in range(0, WORDLE_N * 3, 3):
177
  char_offset = offset + char_idx
178
+ if tuple(state[char_offset : char_offset + 3]) == (0, 0, 0):
179
+ state[char_offset : char_offset + 3] = value
180
 
181
 
182
  def _set_yes(state, offset, char_int, char_pos):
183
  # char at position char_pos = yes,
184
  # all other chars at position char_pos == no
185
  pos_offset = 3 * char_pos
186
+ state[offset + pos_offset : offset + pos_offset + 3] = [0, 0, 1]
187
  for ocint in range(len(WORDLE_CHARS)):
188
  if ocint != char_int:
189
  oc_offset = 1 + ocint * WORDLE_N * 3
190
  yes_index = oc_offset + pos_offset
191
+ state[yes_index : yes_index + 3] = [1, 0, 0]
192
 
193
 
194
  def _set_no(state, offset, char_pos):
195
  # Set offset character = no at char_pos position
196
+ state[offset + 3 * char_pos : offset + 3 * char_pos + 3] = [1, 0, 0]
197
 
198
 
199
  def _set_all_no(state, offset):
200
  # Set offset character = no at all positions
201
+ state[offset : offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
wordle_env/test_wordle.py CHANGED
@@ -1,7 +1,6 @@
1
  import pytest
2
 
3
- from . import wordle
4
- from . import state
5
 
6
  TESTWORDS = [
7
  "APPAA",
 
1
  import pytest
2
 
3
+ from . import state, wordle
 
4
 
5
  TESTWORDS = [
6
  "APPAA",
wordle_env/wordle.py CHANGED
@@ -1,24 +1,22 @@
 
 
 
1
  import gym
2
  from gym import spaces
3
- from typing import Optional, List
4
 
5
  from . import state
6
- from .const import WORDLE_N, REWARD, WORDLE_CHARS
7
-
8
  from .words import complete_vocabulary, target_vocabulary
9
 
10
- import random
11
-
12
 
13
  def _load_words(
14
- limit: Optional[int] = None,
15
- complete: Optional[bool] = False
16
  ) -> List[str]:
17
  words = complete_vocabulary if complete else target_vocabulary
18
  return words if not limit else words[:limit]
19
 
20
 
21
- def get_env(env_id='WordleEnvFull-v0'):
22
  return gym.make(env_id)
23
 
24
 
@@ -42,13 +40,16 @@ class WordleEnvBase(gym.Env):
42
  Initial state with turn 0, all chars Unvisited
43
  """
44
 
45
- def __init__(self, words: List[str],
46
- max_turns: int = 6,
47
- allowable_words: Optional[int] = None,
48
- mask_based_state_updates: bool = False):
 
 
 
49
  assert all(
50
  len(w) == WORDLE_N for w in words
51
- ), f'Not all words of length {WORDLE_N}, {words}'
52
  self.words = words
53
  self.max_turns = max_turns
54
  self.allowable_words = allowable_words
@@ -57,8 +58,7 @@ class WordleEnvBase(gym.Env):
57
  self.allowable_words = len(self.words)
58
 
59
  self.action_space = spaces.Discrete(self.words_as_action_space())
60
- self.observation_space = spaces.MultiDiscrete(
61
- state.get_nvec(self.max_turns))
62
 
63
  self.done = True
64
  self.goal_word: int = -1
@@ -79,15 +79,15 @@ class WordleEnvBase(gym.Env):
79
  word = self.words[action]
80
  goal_word = self.words[self.goal_word]
81
  # assert word in self.words, f'{word} not in words list'
82
- self.state, r = self.state_updater(state=self.state,
83
- word=word,
84
- goal_word=goal_word)
85
 
86
  reward = r
87
  if action == self.goal_word:
88
  self.done = True
89
  # reward = REWARD
90
- if state.remaining_steps(self.state) == self.max_turns-1:
91
  reward = 0 # -10*REWARD # No reward for guessing off the bat
92
  else:
93
  reward = REWARD
@@ -100,7 +100,7 @@ class WordleEnvBase(gym.Env):
100
  def reset(self):
101
  self.state = state.new(self.max_turns)
102
  self.done = False
103
- random_word = random.choice(self.words[:self.allowable_words])
104
  self.goal_word = self.words.index(random_word)
105
  return self.state.copy()
106
 
@@ -121,8 +121,7 @@ class WordleEnv100OneAction(WordleEnvBase):
121
 
122
  class WordleEnv100WithMask(WordleEnvBase):
123
  def __init__(self):
124
- super().__init__(words=_load_words(100),
125
- mask_based_state_updates=True)
126
 
127
 
128
  class WordleEnv100TwoAction(WordleEnvBase):
@@ -142,8 +141,7 @@ class WordleEnv100FullAction(WordleEnvBase):
142
 
143
  class WordleEnv1000WithMask(WordleEnvBase):
144
  def __init__(self):
145
- super().__init__(words=_load_words(1000),
146
- mask_based_state_updates=True)
147
 
148
 
149
  class WordleEnv1000FullAction(WordleEnvBase):
@@ -158,5 +156,6 @@ class WordleEnvFull(WordleEnvBase):
158
 
159
  class WordleEnvRealWithMask(WordleEnvBase):
160
  def __init__(self):
161
- super().__init__(words=_load_words(), allowable_words=2315,
162
- mask_based_state_updates=True)
 
 
1
+ import random
2
+ from typing import List, Optional
3
+
4
  import gym
5
  from gym import spaces
 
6
 
7
  from . import state
8
+ from .const import REWARD, WORDLE_CHARS, WORDLE_N
 
9
  from .words import complete_vocabulary, target_vocabulary
10
 
 
 
11
 
12
  def _load_words(
13
+ limit: Optional[int] = None, complete: Optional[bool] = False
 
14
  ) -> List[str]:
15
  words = complete_vocabulary if complete else target_vocabulary
16
  return words if not limit else words[:limit]
17
 
18
 
19
+ def get_env(env_id="WordleEnvFull-v0"):
20
  return gym.make(env_id)
21
 
22
 
 
40
  Initial state with turn 0, all chars Unvisited
41
  """
42
 
43
+ def __init__(
44
+ self,
45
+ words: List[str],
46
+ max_turns: int = 6,
47
+ allowable_words: Optional[int] = None,
48
+ mask_based_state_updates: bool = False,
49
+ ):
50
  assert all(
51
  len(w) == WORDLE_N for w in words
52
+ ), f"Not all words of length {WORDLE_N}, {words}"
53
  self.words = words
54
  self.max_turns = max_turns
55
  self.allowable_words = allowable_words
 
58
  self.allowable_words = len(self.words)
59
 
60
  self.action_space = spaces.Discrete(self.words_as_action_space())
61
+ self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))
 
62
 
63
  self.done = True
64
  self.goal_word: int = -1
 
79
  word = self.words[action]
80
  goal_word = self.words[self.goal_word]
81
  # assert word in self.words, f'{word} not in words list'
82
+ self.state, r = self.state_updater(
83
+ state=self.state, word=word, goal_word=goal_word
84
+ )
85
 
86
  reward = r
87
  if action == self.goal_word:
88
  self.done = True
89
  # reward = REWARD
90
+ if state.remaining_steps(self.state) == self.max_turns - 1:
91
  reward = 0 # -10*REWARD # No reward for guessing off the bat
92
  else:
93
  reward = REWARD
 
100
  def reset(self):
101
  self.state = state.new(self.max_turns)
102
  self.done = False
103
+ random_word = random.choice(self.words[: self.allowable_words])
104
  self.goal_word = self.words.index(random_word)
105
  return self.state.copy()
106
 
 
121
 
122
  class WordleEnv100WithMask(WordleEnvBase):
123
  def __init__(self):
124
+ super().__init__(words=_load_words(100), mask_based_state_updates=True)
 
125
 
126
 
127
  class WordleEnv100TwoAction(WordleEnvBase):
 
141
 
142
  class WordleEnv1000WithMask(WordleEnvBase):
143
  def __init__(self):
144
+ super().__init__(words=_load_words(1000), mask_based_state_updates=True)
 
145
 
146
 
147
  class WordleEnv1000FullAction(WordleEnvBase):
 
156
 
157
  class WordleEnvRealWithMask(WordleEnvBase):
158
  def __init__(self):
159
+ super().__init__(
160
+ words=_load_words(), allowable_words=2315, mask_based_state_updates=True
161
+ )
wordle_env/words.py CHANGED
@@ -7,7 +7,7 @@ _COMPLETE_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
7
  _TARGET_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
8
  94f3c0303ba6a7768b47583aff36654d/raw/\
9
  d9cddf5e16140df9e14f19c2de76a0ef36fd2748/wordle-La.txt"
10
- _DOWNLOADS_DIR = '.'
11
  _COMPLETE_VOCABULARY_FILENAME = "complete_vocabulary.txt"
12
  _TARGET_VOCABULARY_FILENAME = "target_vocabulary.txt"
13
 
@@ -24,7 +24,11 @@ def _retrieve_vocabulary(url, filename, dir):
24
 
25
 
26
  target_vocabulary = _retrieve_vocabulary(
27
- _TARGET_VOCABULARY_URL, _TARGET_VOCABULARY_FILENAME, _DOWNLOADS_DIR)
28
- complete_vocabulary = _retrieve_vocabulary(
29
- _COMPLETE_VOCABULARY_URL, _COMPLETE_VOCABULARY_FILENAME, _DOWNLOADS_DIR
30
- ) + target_vocabulary
 
 
 
 
 
7
  _TARGET_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
8
  94f3c0303ba6a7768b47583aff36654d/raw/\
9
  d9cddf5e16140df9e14f19c2de76a0ef36fd2748/wordle-La.txt"
10
+ _DOWNLOADS_DIR = "."
11
  _COMPLETE_VOCABULARY_FILENAME = "complete_vocabulary.txt"
12
  _TARGET_VOCABULARY_FILENAME = "target_vocabulary.txt"
13
 
 
24
 
25
 
26
  target_vocabulary = _retrieve_vocabulary(
27
+ _TARGET_VOCABULARY_URL, _TARGET_VOCABULARY_FILENAME, _DOWNLOADS_DIR
28
+ )
29
+ complete_vocabulary = (
30
+ _retrieve_vocabulary(
31
+ _COMPLETE_VOCABULARY_URL, _COMPLETE_VOCABULARY_FILENAME, _DOWNLOADS_DIR
32
+ )
33
+ + target_vocabulary
34
+ )
wordle_game.py CHANGED
@@ -1,30 +1,28 @@
1
- from rich.prompt import Prompt
2
- from rich.console import Console
3
  from random import choice
4
- from wordle_env.words import target_vocabulary, complete_vocabulary
5
 
6
- SQUARES = {
7
- 'correct_place': '🟩',
8
- 'correct_letter': '🟨',
9
- 'incorrect_letter': '⬛'
10
- }
 
11
 
12
- WELCOME_MESSAGE = f'\n[white on blue] WELCOME TO WORDLE [/]\n'
13
  PLAYER_INSTRUCTIONS = "You may start guessing\n"
14
  GUESS_STATEMENT = "\nEnter your guess"
15
  ALLOWED_GUESSES = 6
16
 
17
 
18
  def correct_place(letter):
19
- return f'[black on green]{letter}[/]'
20
 
21
 
22
  def correct_letter(letter):
23
- return f'[black on yellow]{letter}[/]'
24
 
25
 
26
  def incorrect_letter(letter):
27
- return f'[black on white]{letter}[/]'
28
 
29
 
30
  def check_guess(guess, answer):
@@ -34,19 +32,20 @@ def check_guess(guess, answer):
34
  for i, letter in enumerate(guess):
35
  if answer[i] == guess[i]:
36
  guessed[i] = correct_place(letter)
37
- wordle_pattern.append(SQUARES['correct_place'])
38
  processed_letters.append(letter)
39
  for i, letter in enumerate(guess):
40
  if answer[i] != guess[i]:
41
- if (letter in answer and
42
- answer.count(letter) > processed_letters.count(letter)):
 
43
  guessed[i] = correct_letter(letter)
44
- wordle_pattern.append(SQUARES['correct_letter'])
45
  else:
46
  guessed[i] = incorrect_letter(letter)
47
- wordle_pattern.append(SQUARES['incorrect_letter'])
48
  processed_letters.append(letter)
49
- return ''.join(guessed), ''.join(wordle_pattern)
50
 
51
 
52
  def game(console, chosen_word):
@@ -57,12 +56,15 @@ def game(console, chosen_word):
57
 
58
  while not end_of_game:
59
  guess = Prompt.ask(GUESS_STATEMENT).upper()
60
- while (len(guess) != 5 or guess in already_guessed or
61
- guess not in complete_vocabulary):
 
 
 
62
  if guess in already_guessed:
63
  console.print("[red]You've already guessed this word!!\n[/]")
64
  else:
65
- console.print('[red]Please enter a valid 5-letter word!!\n[/]')
66
  guess = Prompt.ask(GUESS_STATEMENT).upper()
67
  already_guessed.append(guess)
68
  guessed, pattern = check_guess(guess, chosen_word)
@@ -74,14 +76,13 @@ def game(console, chosen_word):
74
  end_of_game = True
75
  if len(already_guessed) == ALLOWED_GUESSES and guess != chosen_word:
76
  console.print(f"\n[red]WORDLE X/{ALLOWED_GUESSES}[/]")
77
- console.print(f'\n[green]Correct Word: {chosen_word}[/]')
78
  else:
79
- console.print(
80
- f"\n[green]WORDLE {len(already_guessed)}/{ALLOWED_GUESSES}[/]\n")
81
  console.print(*full_wordle_pattern, sep="\n")
82
 
83
 
84
- if __name__ == '__main__':
85
  console = Console()
86
  chosen_word = choice(target_vocabulary)
87
  console.print(WELCOME_MESSAGE)
 
 
 
1
  from random import choice
 
2
 
3
+ from rich.console import Console
4
+ from rich.prompt import Prompt
5
+
6
+ from wordle_env.words import complete_vocabulary, target_vocabulary
7
+
8
+ SQUARES = {"correct_place": "🟩", "correct_letter": "🟨", "incorrect_letter": "⬛"}
9
 
10
+ WELCOME_MESSAGE = f"\n[white on blue] WELCOME TO WORDLE [/]\n"
11
  PLAYER_INSTRUCTIONS = "You may start guessing\n"
12
  GUESS_STATEMENT = "\nEnter your guess"
13
  ALLOWED_GUESSES = 6
14
 
15
 
16
  def correct_place(letter):
17
+ return f"[black on green]{letter}[/]"
18
 
19
 
20
  def correct_letter(letter):
21
+ return f"[black on yellow]{letter}[/]"
22
 
23
 
24
  def incorrect_letter(letter):
25
+ return f"[black on white]{letter}[/]"
26
 
27
 
28
  def check_guess(guess, answer):
 
32
  for i, letter in enumerate(guess):
33
  if answer[i] == guess[i]:
34
  guessed[i] = correct_place(letter)
35
+ wordle_pattern.append(SQUARES["correct_place"])
36
  processed_letters.append(letter)
37
  for i, letter in enumerate(guess):
38
  if answer[i] != guess[i]:
39
+ if letter in answer and answer.count(letter) > processed_letters.count(
40
+ letter
41
+ ):
42
  guessed[i] = correct_letter(letter)
43
+ wordle_pattern.append(SQUARES["correct_letter"])
44
  else:
45
  guessed[i] = incorrect_letter(letter)
46
+ wordle_pattern.append(SQUARES["incorrect_letter"])
47
  processed_letters.append(letter)
48
+ return "".join(guessed), "".join(wordle_pattern)
49
 
50
 
51
  def game(console, chosen_word):
 
56
 
57
  while not end_of_game:
58
  guess = Prompt.ask(GUESS_STATEMENT).upper()
59
+ while (
60
+ len(guess) != 5
61
+ or guess in already_guessed
62
+ or guess not in complete_vocabulary
63
+ ):
64
  if guess in already_guessed:
65
  console.print("[red]You've already guessed this word!!\n[/]")
66
  else:
67
+ console.print("[red]Please enter a valid 5-letter word!!\n[/]")
68
  guess = Prompt.ask(GUESS_STATEMENT).upper()
69
  already_guessed.append(guess)
70
  guessed, pattern = check_guess(guess, chosen_word)
 
76
  end_of_game = True
77
  if len(already_guessed) == ALLOWED_GUESSES and guess != chosen_word:
78
  console.print(f"\n[red]WORDLE X/{ALLOWED_GUESSES}[/]")
79
+ console.print(f"\n[green]Correct Word: {chosen_word}[/]")
80
  else:
81
+ console.print(f"\n[green]WORDLE {len(already_guessed)}/{ALLOWED_GUESSES}[/]\n")
 
82
  console.print(*full_wordle_pattern, sep="\n")
83
 
84
 
85
+ if __name__ == "__main__":
86
  console = Console()
87
  chosen_word = choice(target_vocabulary)
88
  console.print(WELCOME_MESSAGE)