Spaces:
Sleeping
Sleeping
Fix code style with black and isort
Browse files- a3c/eval.py +3 -4
- a3c/net.py +5 -5
- a3c/play.py +6 -8
- a3c/shared_adam.py +8 -10
- a3c/train.py +40 -14
- a3c/utils.py +1 -1
- a3c/worker.py +52 -43
- api_rest/api.py +17 -15
- main.py +41 -41
- rs_wordle_player/firebase_connector.py +20 -19
- rs_wordle_player/rs_wordle_player.py +3 -2
- rs_wordle_player/selenium_player.py +14 -15
- wordle_env/__init__.py +4 -8
- wordle_env/const.py +1 -1
- wordle_env/state.py +16 -25
- wordle_env/test_wordle.py +1 -2
- wordle_env/wordle.py +25 -26
- wordle_env/words.py +9 -5
- wordle_game.py +26 -25
a3c/eval.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
import torch
|
3 |
|
4 |
from .net import GreedyNet
|
@@ -14,9 +15,7 @@ def evaluate_checkpoints(dir, env):
|
|
14 |
wins, guesses = evaluate(env, pretrained_model_path)
|
15 |
results[checkpoint] = wins, guesses
|
16 |
return dict(
|
17 |
-
sorted(results.items(), key=lambda x: (
|
18 |
-
x[1][0], -x[1][1]), reverse=True
|
19 |
-
)
|
20 |
)
|
21 |
|
22 |
|
@@ -39,4 +38,4 @@ def evaluate(env, pretrained_model_path):
|
|
39 |
took {n_win_guesses/n_wins} guesses per win, "
|
40 |
f"{n_guesses / N} including losses."
|
41 |
)
|
42 |
-
return n_wins/N*100, n_win_guesses/n_wins
|
|
|
1 |
import os
|
2 |
+
|
3 |
import torch
|
4 |
|
5 |
from .net import GreedyNet
|
|
|
15 |
wins, guesses = evaluate(env, pretrained_model_path)
|
16 |
results[checkpoint] = wins, guesses
|
17 |
return dict(
|
18 |
+
sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True)
|
|
|
|
|
19 |
)
|
20 |
|
21 |
|
|
|
38 |
took {n_win_guesses/n_wins} guesses per win, "
|
39 |
f"{n_guesses / N} including losses."
|
40 |
)
|
41 |
+
return n_wins / N * 100, n_win_guesses / n_wins
|
a3c/net.py
CHANGED
@@ -1,7 +1,7 @@
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
-
import numpy as np
|
5 |
|
6 |
|
7 |
class Net(nn.Module):
|
@@ -23,15 +23,15 @@ class Net(nn.Module):
|
|
23 |
word_array = np.zeros((word_width, len(word_list)))
|
24 |
for i, word in enumerate(word_list):
|
25 |
for j, c in enumerate(word):
|
26 |
-
word_array[j*26 + (ord(c) - ord(
|
27 |
self.words = torch.Tensor(word_array)
|
28 |
|
29 |
def forward(self, x):
|
30 |
values = self.v1(x.float())
|
31 |
logits = torch.log_softmax(
|
32 |
-
torch.tensordot(self.actor_head(values), self.words,
|
33 |
-
|
34 |
-
|
35 |
values = self.v4(values)
|
36 |
return logits, values
|
37 |
|
|
|
1 |
+
import numpy as np
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
|
|
5 |
|
6 |
|
7 |
class Net(nn.Module):
|
|
|
23 |
word_array = np.zeros((word_width, len(word_list)))
|
24 |
for i, word in enumerate(word_list):
|
25 |
for j, c in enumerate(word):
|
26 |
+
word_array[j * 26 + (ord(c) - ord("A")), i] = 1
|
27 |
self.words = torch.Tensor(word_array)
|
28 |
|
29 |
def forward(self, x):
|
30 |
values = self.v1(x.float())
|
31 |
logits = torch.log_softmax(
|
32 |
+
torch.tensordot(self.actor_head(values), self.words, dims=((1,), (0,))),
|
33 |
+
dim=-1,
|
34 |
+
)
|
35 |
values = self.v4(values)
|
36 |
return logits, values
|
37 |
|
a3c/play.py
CHANGED
@@ -1,15 +1,18 @@
|
|
1 |
import os
|
|
|
2 |
import torch
|
3 |
from dotenv import load_dotenv
|
|
|
4 |
from wordle_env.state import update_from_mask
|
|
|
5 |
from .net import GreedyNet
|
6 |
from .utils import v_wrap
|
7 |
|
8 |
|
9 |
def get_play_model_path():
|
10 |
load_dotenv()
|
11 |
-
model_name = os.getenv(
|
12 |
-
model_checkpoint_dir = os.path.join(
|
13 |
return os.path.join(model_checkpoint_dir, model_name)
|
14 |
|
15 |
|
@@ -28,12 +31,7 @@ def get_initial_state(env):
|
|
28 |
return state
|
29 |
|
30 |
|
31 |
-
def suggest(
|
32 |
-
env,
|
33 |
-
words,
|
34 |
-
states,
|
35 |
-
pretrained_model_path
|
36 |
-
) -> str:
|
37 |
"""
|
38 |
Given a list of words and masks, return the next suggested word
|
39 |
|
|
|
1 |
import os
|
2 |
+
|
3 |
import torch
|
4 |
from dotenv import load_dotenv
|
5 |
+
|
6 |
from wordle_env.state import update_from_mask
|
7 |
+
|
8 |
from .net import GreedyNet
|
9 |
from .utils import v_wrap
|
10 |
|
11 |
|
12 |
def get_play_model_path():
|
13 |
load_dotenv()
|
14 |
+
model_name = os.getenv("RS_WORDLE_MODEL_NAME")
|
15 |
+
model_checkpoint_dir = os.path.join("checkpoints", "best_models")
|
16 |
return os.path.join(model_checkpoint_dir, model_name)
|
17 |
|
18 |
|
|
|
31 |
return state
|
32 |
|
33 |
|
34 |
+
def suggest(env, words, states, pretrained_model_path) -> str:
|
|
|
|
|
|
|
|
|
|
|
35 |
"""
|
36 |
Given a list of words and masks, return the next suggested word
|
37 |
|
a3c/shared_adam.py
CHANGED
@@ -6,20 +6,18 @@ import torch
|
|
6 |
|
7 |
|
8 |
class SharedAdam(torch.optim.Adam):
|
9 |
-
def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
|
10 |
-
weight_decay=0):
|
11 |
super(SharedAdam, self).__init__(
|
12 |
-
params, lr=lr,
|
13 |
-
betas=betas, eps=eps, weight_decay=weight_decay
|
14 |
)
|
15 |
# State initialization
|
16 |
for group in self.param_groups:
|
17 |
-
for p in group[
|
18 |
state = self.state[p]
|
19 |
-
state[
|
20 |
-
state[
|
21 |
-
state[
|
22 |
|
23 |
# share in memory
|
24 |
-
state[
|
25 |
-
state[
|
|
|
6 |
|
7 |
|
8 |
class SharedAdam(torch.optim.Adam):
|
9 |
+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0):
|
|
|
10 |
super(SharedAdam, self).__init__(
|
11 |
+
params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
|
|
|
12 |
)
|
13 |
# State initialization
|
14 |
for group in self.param_groups:
|
15 |
+
for p in group["params"]:
|
16 |
state = self.state[p]
|
17 |
+
state["step"] = 0
|
18 |
+
state["exp_avg"] = torch.zeros_like(p.data)
|
19 |
+
state["exp_avg_sq"] = torch.zeros_like(p.data)
|
20 |
|
21 |
# share in memory
|
22 |
+
state["exp_avg"].share_memory_()
|
23 |
+
state["exp_avg_sq"].share_memory_()
|
a3c/train.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
import os
|
2 |
-
import numpy as np
|
3 |
import random
|
|
|
|
|
4 |
import torch
|
5 |
import torch.multiprocessing as mp
|
6 |
-
|
7 |
from .net import Net
|
|
|
8 |
from .worker import Worker
|
9 |
|
10 |
|
@@ -25,12 +27,12 @@ def train(
|
|
25 |
env,
|
26 |
max_ep,
|
27 |
model_checkpoint_dir,
|
28 |
-
gamma=0
|
29 |
seed=100,
|
30 |
pretrained_model_path=None,
|
31 |
save=False,
|
32 |
min_reward=9.9,
|
33 |
-
every_n_save=100
|
34 |
):
|
35 |
os.environ["OMP_NUM_THREADS"] = "1"
|
36 |
if not os.path.exists(model_checkpoint_dir):
|
@@ -45,18 +47,40 @@ def train(
|
|
45 |
if pretrained_model_path:
|
46 |
gnet.load_state_dict(torch.load(pretrained_model_path))
|
47 |
gnet.share_memory() # share the global parameters in multiprocessing
|
48 |
-
opt = SharedAdam(
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
# parallel training
|
54 |
workers = [
|
55 |
Worker(
|
56 |
-
max_ep,
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
]
|
61 |
[w.start() for w in workers]
|
62 |
res = [] # record episode reward to plot
|
@@ -68,6 +92,8 @@ def train(
|
|
68 |
break
|
69 |
[w.join() for w in workers]
|
70 |
if save:
|
71 |
-
torch.save(
|
72 |
-
|
|
|
|
|
73 |
return global_ep, win_ep, gnet, res
|
|
|
1 |
import os
|
|
|
2 |
import random
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
import torch
|
6 |
import torch.multiprocessing as mp
|
7 |
+
|
8 |
from .net import Net
|
9 |
+
from .shared_adam import SharedAdam
|
10 |
from .worker import Worker
|
11 |
|
12 |
|
|
|
27 |
env,
|
28 |
max_ep,
|
29 |
model_checkpoint_dir,
|
30 |
+
gamma=0.0,
|
31 |
seed=100,
|
32 |
pretrained_model_path=None,
|
33 |
save=False,
|
34 |
min_reward=9.9,
|
35 |
+
every_n_save=100,
|
36 |
):
|
37 |
os.environ["OMP_NUM_THREADS"] = "1"
|
38 |
if not os.path.exists(model_checkpoint_dir):
|
|
|
47 |
if pretrained_model_path:
|
48 |
gnet.load_state_dict(torch.load(pretrained_model_path))
|
49 |
gnet.share_memory() # share the global parameters in multiprocessing
|
50 |
+
opt = SharedAdam(
|
51 |
+
gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)
|
52 |
+
) # global optimizer
|
53 |
+
global_ep, global_ep_r, res_queue, win_ep = (
|
54 |
+
mp.Value("i", 0),
|
55 |
+
mp.Value("d", 0.0),
|
56 |
+
mp.Queue(),
|
57 |
+
mp.Value("i", 0),
|
58 |
+
)
|
59 |
|
60 |
# parallel training
|
61 |
workers = [
|
62 |
Worker(
|
63 |
+
max_ep,
|
64 |
+
gnet,
|
65 |
+
opt,
|
66 |
+
global_ep,
|
67 |
+
global_ep_r,
|
68 |
+
res_queue,
|
69 |
+
i,
|
70 |
+
env,
|
71 |
+
n_s,
|
72 |
+
n_a,
|
73 |
+
words_list,
|
74 |
+
word_width,
|
75 |
+
win_ep,
|
76 |
+
model_checkpoint_dir,
|
77 |
+
gamma,
|
78 |
+
pretrained_model_path,
|
79 |
+
save,
|
80 |
+
min_reward,
|
81 |
+
every_n_save,
|
82 |
+
)
|
83 |
+
for i in range(mp.cpu_count())
|
84 |
]
|
85 |
[w.start() for w in workers]
|
86 |
res = [] # record episode reward to plot
|
|
|
92 |
break
|
93 |
[w.join() for w in workers]
|
94 |
if save:
|
95 |
+
torch.save(
|
96 |
+
gnet.state_dict(),
|
97 |
+
os.path.join(model_checkpoint_dir, f"model_{env.unwrapped.spec.id}.pth"),
|
98 |
+
)
|
99 |
return global_ep, win_ep, gnet, res
|
a3c/utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
import torch
|
2 |
import numpy as np
|
|
|
3 |
|
4 |
|
5 |
def v_wrap(np_array, dtype=np.float32):
|
|
|
|
|
1 |
import numpy as np
|
2 |
+
import torch
|
3 |
|
4 |
|
5 |
def v_wrap(np_array, dtype=np.float32):
|
a3c/worker.py
CHANGED
@@ -2,40 +2,42 @@
|
|
2 |
Worker class implementation of the a3c discrete algorithm
|
3 |
"""
|
4 |
import os
|
5 |
-
|
6 |
import numpy as np
|
|
|
7 |
import torch.multiprocessing as mp
|
8 |
from torch import nn
|
|
|
9 |
from .net import Net
|
10 |
from .utils import v_wrap
|
11 |
|
12 |
|
13 |
class Worker(mp.Process):
|
14 |
def __init__(
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
):
|
36 |
super(Worker, self).__init__()
|
37 |
self.max_ep = max_ep
|
38 |
-
self.name =
|
39 |
self.g_ep = global_ep
|
40 |
self.g_ep_r = global_ep_r
|
41 |
self.res_queue = res_queue
|
@@ -57,7 +59,7 @@ class Worker(mp.Process):
|
|
57 |
while self.g_ep.value < self.max_ep:
|
58 |
s = self.env.reset()
|
59 |
buffer_s, buffer_a, buffer_r = [], [], []
|
60 |
-
ep_r = 0.
|
61 |
while True:
|
62 |
a = self.lnet.choose_action(v_wrap(s[None, :]))
|
63 |
s_, r, done, _ = self.env.step(a)
|
@@ -68,11 +70,9 @@ class Worker(mp.Process):
|
|
68 |
|
69 |
if done: # update global and assign to local net
|
70 |
# sync
|
71 |
-
self.push_and_pull(done, s_, buffer_s,
|
72 |
-
buffer_a, buffer_r)
|
73 |
goal_word = self.word_list[self.env.goal_word]
|
74 |
-
self.record(ep_r, goal_word,
|
75 |
-
self.word_list[a], len(buffer_a))
|
76 |
self.save_model()
|
77 |
buffer_s, buffer_a, buffer_r = [], [], []
|
78 |
break
|
@@ -81,22 +81,22 @@ class Worker(mp.Process):
|
|
81 |
|
82 |
def push_and_pull(self, done, s_, bs, ba, br):
|
83 |
if done:
|
84 |
-
v_s_ = 0.
|
85 |
else:
|
86 |
-
v_s_ = self.lnet.forward(v_wrap(
|
87 |
-
s_[None, :]))[-1].data.numpy()[0, 0]
|
88 |
|
89 |
buffer_v_target = []
|
90 |
-
for r in br[::-1]:
|
91 |
v_s_ = r + self.gamma * v_s_
|
92 |
buffer_v_target.append(v_s_)
|
93 |
buffer_v_target.reverse()
|
94 |
|
95 |
loss = self.lnet.loss_func(
|
96 |
v_wrap(np.vstack(bs)),
|
97 |
-
v_wrap(np.array(ba), dtype=np.int64)
|
98 |
-
ba[0].dtype == np.int64
|
99 |
-
v_wrap(np.
|
|
|
100 |
)
|
101 |
|
102 |
# calculate local gradients and push local parameters to global
|
@@ -110,16 +110,21 @@ class Worker(mp.Process):
|
|
110 |
self.lnet.load_state_dict(self.gnet.state_dict())
|
111 |
|
112 |
def save_model(self):
|
113 |
-
if (
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
def record(self, ep_r, goal_word, action, action_number):
|
119 |
with self.g_ep.get_lock():
|
120 |
self.g_ep.value += 1
|
121 |
with self.g_ep_r.get_lock():
|
122 |
-
if self.g_ep_r.value == 0
|
123 |
self.g_ep_r.value = ep_r
|
124 |
else:
|
125 |
self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
|
@@ -129,9 +134,13 @@ class Worker(mp.Process):
|
|
129 |
if self.g_ep.value % 100 == 0:
|
130 |
print(
|
131 |
self.name,
|
132 |
-
"Ep:",
|
|
|
133 |
"| Ep_r: %.0f" % self.g_ep_r.value,
|
134 |
-
"| Goal :",
|
135 |
-
|
136 |
-
"|
|
|
|
|
|
|
|
137 |
)
|
|
|
2 |
Worker class implementation of the a3c discrete algorithm
|
3 |
"""
|
4 |
import os
|
5 |
+
|
6 |
import numpy as np
|
7 |
+
import torch
|
8 |
import torch.multiprocessing as mp
|
9 |
from torch import nn
|
10 |
+
|
11 |
from .net import Net
|
12 |
from .utils import v_wrap
|
13 |
|
14 |
|
15 |
class Worker(mp.Process):
|
16 |
def __init__(
|
17 |
+
self,
|
18 |
+
max_ep,
|
19 |
+
gnet,
|
20 |
+
opt,
|
21 |
+
global_ep,
|
22 |
+
global_ep_r,
|
23 |
+
res_queue,
|
24 |
+
name,
|
25 |
+
env,
|
26 |
+
N_S,
|
27 |
+
N_A,
|
28 |
+
words_list,
|
29 |
+
word_width,
|
30 |
+
winning_ep,
|
31 |
+
model_checkpoint_dir,
|
32 |
+
gamma=0.0,
|
33 |
+
pretrained_model_path=None,
|
34 |
+
save=False,
|
35 |
+
min_reward=9.9,
|
36 |
+
every_n_save=100,
|
37 |
):
|
38 |
super(Worker, self).__init__()
|
39 |
self.max_ep = max_ep
|
40 |
+
self.name = "w%02i" % name
|
41 |
self.g_ep = global_ep
|
42 |
self.g_ep_r = global_ep_r
|
43 |
self.res_queue = res_queue
|
|
|
59 |
while self.g_ep.value < self.max_ep:
|
60 |
s = self.env.reset()
|
61 |
buffer_s, buffer_a, buffer_r = [], [], []
|
62 |
+
ep_r = 0.0
|
63 |
while True:
|
64 |
a = self.lnet.choose_action(v_wrap(s[None, :]))
|
65 |
s_, r, done, _ = self.env.step(a)
|
|
|
70 |
|
71 |
if done: # update global and assign to local net
|
72 |
# sync
|
73 |
+
self.push_and_pull(done, s_, buffer_s, buffer_a, buffer_r)
|
|
|
74 |
goal_word = self.word_list[self.env.goal_word]
|
75 |
+
self.record(ep_r, goal_word, self.word_list[a], len(buffer_a))
|
|
|
76 |
self.save_model()
|
77 |
buffer_s, buffer_a, buffer_r = [], [], []
|
78 |
break
|
|
|
81 |
|
82 |
def push_and_pull(self, done, s_, bs, ba, br):
|
83 |
if done:
|
84 |
+
v_s_ = 0.0 # terminal
|
85 |
else:
|
86 |
+
v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
|
|
|
87 |
|
88 |
buffer_v_target = []
|
89 |
+
for r in br[::-1]: # reverse buffer r
|
90 |
v_s_ = r + self.gamma * v_s_
|
91 |
buffer_v_target.append(v_s_)
|
92 |
buffer_v_target.reverse()
|
93 |
|
94 |
loss = self.lnet.loss_func(
|
95 |
v_wrap(np.vstack(bs)),
|
96 |
+
v_wrap(np.array(ba), dtype=np.int64)
|
97 |
+
if ba[0].dtype == np.int64
|
98 |
+
else v_wrap(np.vstack(ba)),
|
99 |
+
v_wrap(np.array(buffer_v_target)[:, None]),
|
100 |
)
|
101 |
|
102 |
# calculate local gradients and push local parameters to global
|
|
|
110 |
self.lnet.load_state_dict(self.gnet.state_dict())
|
111 |
|
112 |
def save_model(self):
|
113 |
+
if (
|
114 |
+
self.save
|
115 |
+
and self.g_ep_r.value >= self.min_reward
|
116 |
+
and self.g_ep.value % self.every_n_save == 0
|
117 |
+
):
|
118 |
+
torch.save(
|
119 |
+
self.gnet.state_dict(),
|
120 |
+
os.path.join(self.model_checkpoint_dir, f"model_{self.g_ep.value}.pth"),
|
121 |
+
)
|
122 |
|
123 |
def record(self, ep_r, goal_word, action, action_number):
|
124 |
with self.g_ep.get_lock():
|
125 |
self.g_ep.value += 1
|
126 |
with self.g_ep_r.get_lock():
|
127 |
+
if self.g_ep_r.value == 0.0:
|
128 |
self.g_ep_r.value = ep_r
|
129 |
else:
|
130 |
self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
|
|
|
134 |
if self.g_ep.value % 100 == 0:
|
135 |
print(
|
136 |
self.name,
|
137 |
+
"Ep:",
|
138 |
+
self.g_ep.value,
|
139 |
"| Ep_r: %.0f" % self.g_ep_r.value,
|
140 |
+
"| Goal :",
|
141 |
+
goal_word,
|
142 |
+
"| Action: ",
|
143 |
+
action,
|
144 |
+
"| Actions: ",
|
145 |
+
action_number,
|
146 |
)
|
api_rest/api.py
CHANGED
@@ -1,30 +1,32 @@
|
|
1 |
import random
|
2 |
-
|
3 |
-
from flask import Flask,
|
4 |
from flask_cors import cross_origin
|
5 |
-
|
|
|
6 |
from wordle_env.wordle import get_env
|
|
|
7 |
|
8 |
app = Flask(__name__)
|
9 |
|
10 |
|
11 |
def validate_goal_word(word):
|
12 |
if not word:
|
13 |
-
return True,
|
14 |
if word.upper() not in target_vocabulary:
|
15 |
-
return True,
|
16 |
-
return False,
|
17 |
|
18 |
|
19 |
-
@app.route(
|
20 |
-
@cross_origin(origin=
|
21 |
def get_play():
|
22 |
# Get the goal word from the request
|
23 |
-
word = request.args.get(
|
24 |
|
25 |
error, msge = validate_goal_word(word)
|
26 |
if error:
|
27 |
-
return jsonify({
|
28 |
|
29 |
word = word.upper()
|
30 |
env = get_env()
|
@@ -32,16 +34,16 @@ def get_play():
|
|
32 |
# Call the play function with the goal word
|
33 |
# and return the attempts and the result
|
34 |
won, attempts = play(env, model_path, word)
|
35 |
-
return jsonify({
|
36 |
|
37 |
|
38 |
-
@app.route(
|
39 |
-
@cross_origin(origin=
|
40 |
def get_word():
|
41 |
# Get a random word from the target vocabulary used to train the model
|
42 |
word = random.choice(target_vocabulary)
|
43 |
word = word.upper()
|
44 |
-
return jsonify({
|
45 |
|
46 |
|
47 |
def create_app(settings_override=None):
|
@@ -58,5 +60,5 @@ def create_app(settings_override=None):
|
|
58 |
return app
|
59 |
|
60 |
|
61 |
-
if __name__ ==
|
62 |
app.run(debug=True)
|
|
|
1 |
import random
|
2 |
+
|
3 |
+
from flask import Flask, jsonify, request
|
4 |
from flask_cors import cross_origin
|
5 |
+
|
6 |
+
from a3c.play import get_play_model_path, play
|
7 |
from wordle_env.wordle import get_env
|
8 |
+
from wordle_env.words import target_vocabulary
|
9 |
|
10 |
app = Flask(__name__)
|
11 |
|
12 |
|
13 |
def validate_goal_word(word):
|
14 |
if not word:
|
15 |
+
return True, "Goal word not provided"
|
16 |
if word.upper() not in target_vocabulary:
|
17 |
+
return True, "Goal word not in vocabulary"
|
18 |
+
return False, ""
|
19 |
|
20 |
|
21 |
+
@app.route("/play_word", methods=["GET"])
|
22 |
+
@cross_origin(origin="*", headers=["Content-Type", "Authorization"])
|
23 |
def get_play():
|
24 |
# Get the goal word from the request
|
25 |
+
word = request.args.get("goal_word")
|
26 |
|
27 |
error, msge = validate_goal_word(word)
|
28 |
if error:
|
29 |
+
return jsonify({"error": msge}), 400
|
30 |
|
31 |
word = word.upper()
|
32 |
env = get_env()
|
|
|
34 |
# Call the play function with the goal word
|
35 |
# and return the attempts and the result
|
36 |
won, attempts = play(env, model_path, word)
|
37 |
+
return jsonify({"attempts": attempts, "won": won})
|
38 |
|
39 |
|
40 |
+
@app.route("/word", methods=["GET"])
|
41 |
+
@cross_origin(origin="*", headers=["Content-Type", "Authorization"])
|
42 |
def get_word():
|
43 |
# Get a random word from the target vocabulary used to train the model
|
44 |
word = random.choice(target_vocabulary)
|
45 |
word = word.upper()
|
46 |
+
return jsonify({"word": word})
|
47 |
|
48 |
|
49 |
def create_app(settings_override=None):
|
|
|
60 |
return app
|
61 |
|
62 |
|
63 |
+
if __name__ == "__main__":
|
64 |
app.run(debug=True)
|
main.py
CHANGED
@@ -3,23 +3,33 @@
|
|
3 |
import argparse
|
4 |
import os
|
5 |
import time
|
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
-
|
8 |
from a3c.eval import evaluate, evaluate_checkpoints
|
9 |
from a3c.play import suggest
|
|
|
10 |
from wordle_env.wordle import get_env
|
11 |
|
12 |
|
13 |
def training_mode(args, env, model_checkpoint_dir):
|
14 |
max_ep = args.games
|
15 |
start_time = time.time()
|
16 |
-
pretrained_model_path =
|
17 |
-
model_checkpoint_dir, args.model_name
|
18 |
-
|
|
|
|
|
19 |
global_ep, win_ep, gnet, res = train(
|
20 |
-
env,
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
)
|
24 |
print("--- %.0f seconds ---" % (time.time() - start_time))
|
25 |
print_results(global_ep, win_ep, res)
|
@@ -34,8 +44,8 @@ def evaluation_mode(args, env, model_checkpoint_dir):
|
|
34 |
|
35 |
def play_mode(args, env, model_checkpoint_dir):
|
36 |
print("Play mode")
|
37 |
-
words = [word.strip() for word in args.words.split(
|
38 |
-
states = [state.strip() for state in args.states.split(
|
39 |
pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
|
40 |
word = suggest(env, words, states, pretrained_model_path)
|
41 |
print(word)
|
@@ -45,8 +55,8 @@ def print_results(global_ep, win_ep, res):
|
|
45 |
print("Jugadas:", global_ep.value)
|
46 |
print("Ganadas:", win_ep.value)
|
47 |
plt.plot(res)
|
48 |
-
plt.ylabel(
|
49 |
-
plt.xlabel(
|
50 |
plt.show()
|
51 |
|
52 |
|
@@ -55,90 +65,80 @@ if __name__ == "__main__":
|
|
55 |
parser.add_argument(
|
56 |
"enviroment",
|
57 |
help="Enviroment (type of wordle game) used for training, \
|
58 |
-
example: WordleEnvFull-v0"
|
59 |
)
|
60 |
parser.add_argument(
|
61 |
"--models_dir",
|
62 |
help="Directory where models are saved (default=checkpoints)",
|
63 |
-
default=
|
64 |
)
|
65 |
-
subparsers = parser.add_subparsers(help=
|
66 |
|
67 |
parser_train = subparsers.add_parser(
|
68 |
-
|
69 |
-
help='Train a model from scratch or train from pretrained model'
|
70 |
)
|
71 |
parser_train.add_argument(
|
72 |
-
"--games",
|
73 |
-
"-g",
|
74 |
-
help="Number of games to train",
|
75 |
-
type=int,
|
76 |
-
required=True
|
77 |
)
|
78 |
parser_train.add_argument(
|
79 |
"--model_name",
|
80 |
"-m",
|
81 |
help="If want to train from a pretrained model, \
|
82 |
-
the name of the pretrained model file"
|
83 |
)
|
84 |
parser_train.add_argument(
|
85 |
"--gamma",
|
86 |
help="Gamma hyperparameter (discount factor) value",
|
87 |
type=float,
|
88 |
-
default=0.
|
89 |
)
|
90 |
parser_train.add_argument(
|
91 |
-
"--seed",
|
92 |
-
help="Seed used for random numbers generation",
|
93 |
-
type=int,
|
94 |
-
default=100
|
95 |
)
|
96 |
parser_train.add_argument(
|
97 |
"--save",
|
98 |
-
|
99 |
help="Save instances of the model while training",
|
100 |
-
action=
|
101 |
)
|
102 |
parser_train.add_argument(
|
103 |
"--min_reward",
|
104 |
help="The minimun global reward value achieved for saving the model",
|
105 |
type=float,
|
106 |
-
default=9.9
|
107 |
)
|
108 |
parser_train.add_argument(
|
109 |
"--every_n_save",
|
110 |
help="Check every n training steps to save the model",
|
111 |
type=int,
|
112 |
-
default=100
|
113 |
)
|
114 |
parser_train.set_defaults(func=training_mode)
|
115 |
|
116 |
parser_eval = subparsers.add_parser(
|
117 |
-
|
|
|
118 |
parser_eval.set_defaults(func=evaluation_mode)
|
119 |
|
120 |
parser_play = subparsers.add_parser(
|
121 |
-
|
122 |
-
help=
|
123 |
-
and the model will try to predict the goal word
|
124 |
)
|
125 |
parser_play.add_argument(
|
126 |
-
"--words",
|
127 |
-
"-w",
|
128 |
-
help="List of words played in the wordle game",
|
129 |
-
required=True
|
130 |
)
|
131 |
parser_play.add_argument(
|
132 |
"--states",
|
133 |
"-st",
|
134 |
help="List of states returned by playing each of the words",
|
135 |
-
required=True
|
136 |
)
|
137 |
parser_play.add_argument(
|
138 |
"--model_name",
|
139 |
"-m",
|
140 |
help="Name of the pretrained model file thich will play the game",
|
141 |
-
required=True
|
142 |
)
|
143 |
parser_play.set_defaults(func=play_mode)
|
144 |
|
|
|
3 |
import argparse
|
4 |
import os
|
5 |
import time
|
6 |
+
|
7 |
import matplotlib.pyplot as plt
|
8 |
+
|
9 |
from a3c.eval import evaluate, evaluate_checkpoints
|
10 |
from a3c.play import suggest
|
11 |
+
from a3c.train import train
|
12 |
from wordle_env.wordle import get_env
|
13 |
|
14 |
|
15 |
def training_mode(args, env, model_checkpoint_dir):
|
16 |
max_ep = args.games
|
17 |
start_time = time.time()
|
18 |
+
pretrained_model_path = (
|
19 |
+
os.path.join(model_checkpoint_dir, args.model_name)
|
20 |
+
if args.model_name
|
21 |
+
else args.model_name
|
22 |
+
)
|
23 |
global_ep, win_ep, gnet, res = train(
|
24 |
+
env,
|
25 |
+
max_ep,
|
26 |
+
model_checkpoint_dir,
|
27 |
+
args.gamma,
|
28 |
+
args.seed,
|
29 |
+
pretrained_model_path,
|
30 |
+
args.save,
|
31 |
+
args.min_reward,
|
32 |
+
args.every_n_save,
|
33 |
)
|
34 |
print("--- %.0f seconds ---" % (time.time() - start_time))
|
35 |
print_results(global_ep, win_ep, res)
|
|
|
44 |
|
45 |
def play_mode(args, env, model_checkpoint_dir):
|
46 |
print("Play mode")
|
47 |
+
words = [word.strip() for word in args.words.split(",")]
|
48 |
+
states = [state.strip() for state in args.states.split(",")]
|
49 |
pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
|
50 |
word = suggest(env, words, states, pretrained_model_path)
|
51 |
print(word)
|
|
|
55 |
print("Jugadas:", global_ep.value)
|
56 |
print("Ganadas:", win_ep.value)
|
57 |
plt.plot(res)
|
58 |
+
plt.ylabel("Moving average ep reward")
|
59 |
+
plt.xlabel("Step")
|
60 |
plt.show()
|
61 |
|
62 |
|
|
|
65 |
parser.add_argument(
|
66 |
"enviroment",
|
67 |
help="Enviroment (type of wordle game) used for training, \
|
68 |
+
example: WordleEnvFull-v0",
|
69 |
)
|
70 |
parser.add_argument(
|
71 |
"--models_dir",
|
72 |
help="Directory where models are saved (default=checkpoints)",
|
73 |
+
default="checkpoints",
|
74 |
)
|
75 |
+
subparsers = parser.add_subparsers(help="sub-command help")
|
76 |
|
77 |
parser_train = subparsers.add_parser(
|
78 |
+
"train", help="Train a model from scratch or train from pretrained model"
|
|
|
79 |
)
|
80 |
parser_train.add_argument(
|
81 |
+
"--games", "-g", help="Number of games to train", type=int, required=True
|
|
|
|
|
|
|
|
|
82 |
)
|
83 |
parser_train.add_argument(
|
84 |
"--model_name",
|
85 |
"-m",
|
86 |
help="If want to train from a pretrained model, \
|
87 |
+
the name of the pretrained model file",
|
88 |
)
|
89 |
parser_train.add_argument(
|
90 |
"--gamma",
|
91 |
help="Gamma hyperparameter (discount factor) value",
|
92 |
type=float,
|
93 |
+
default=0.0,
|
94 |
)
|
95 |
parser_train.add_argument(
|
96 |
+
"--seed", help="Seed used for random numbers generation", type=int, default=100
|
|
|
|
|
|
|
97 |
)
|
98 |
parser_train.add_argument(
|
99 |
"--save",
|
100 |
+
"-s",
|
101 |
help="Save instances of the model while training",
|
102 |
+
action="store_true",
|
103 |
)
|
104 |
parser_train.add_argument(
|
105 |
"--min_reward",
|
106 |
help="The minimun global reward value achieved for saving the model",
|
107 |
type=float,
|
108 |
+
default=9.9,
|
109 |
)
|
110 |
parser_train.add_argument(
|
111 |
"--every_n_save",
|
112 |
help="Check every n training steps to save the model",
|
113 |
type=int,
|
114 |
+
default=100,
|
115 |
)
|
116 |
parser_train.set_defaults(func=training_mode)
|
117 |
|
118 |
parser_eval = subparsers.add_parser(
|
119 |
+
"eval", help="Evaluate saved models for the enviroment"
|
120 |
+
)
|
121 |
parser_eval.set_defaults(func=evaluation_mode)
|
122 |
|
123 |
parser_play = subparsers.add_parser(
|
124 |
+
"play",
|
125 |
+
help="Give the model a word and the state result \
|
126 |
+
and the model will try to predict the goal word",
|
127 |
)
|
128 |
parser_play.add_argument(
|
129 |
+
"--words", "-w", help="List of words played in the wordle game", required=True
|
|
|
|
|
|
|
130 |
)
|
131 |
parser_play.add_argument(
|
132 |
"--states",
|
133 |
"-st",
|
134 |
help="List of states returned by playing each of the words",
|
135 |
+
required=True,
|
136 |
)
|
137 |
parser_play.add_argument(
|
138 |
"--model_name",
|
139 |
"-m",
|
140 |
help="Name of the pretrained model file thich will play the game",
|
141 |
+
required=True,
|
142 |
)
|
143 |
parser_play.set_defaults(func=play_mode)
|
144 |
|
rs_wordle_player/firebase_connector.py
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
import os
|
2 |
-
import firebase_admin
|
3 |
-
from firebase_admin import credentials
|
4 |
-
from firebase_admin import firestore
|
5 |
from datetime import datetime
|
6 |
-
from dotenv import load_dotenv
|
7 |
|
|
|
|
|
|
|
8 |
|
9 |
-
class FirebaseConnector():
|
10 |
|
|
|
11 |
def __init__(self):
|
12 |
load_dotenv()
|
13 |
cert_path = self.get_credentials_path()
|
@@ -20,32 +19,34 @@ class FirebaseConnector():
|
|
20 |
return db
|
21 |
|
22 |
def get_credentials_path(self):
|
23 |
-
credentials_path = os.getenv(
|
24 |
return credentials_path
|
25 |
|
26 |
def get_user(self):
|
27 |
-
user = os.getenv(
|
28 |
return user
|
29 |
|
30 |
def get_state_from_fb_result(self, firebase_result):
|
31 |
-
result_number_map = {
|
32 |
-
'misplaced': '1',
|
33 |
-
'correct': '2'}
|
34 |
char_result_map = map(
|
35 |
lambda char_res: result_number_map[char_res], firebase_result
|
36 |
)
|
37 |
-
return
|
38 |
|
39 |
def today(self):
|
40 |
-
return datetime.today().strftime(
|
41 |
|
42 |
def today_user_results(self):
|
43 |
-
daily_results_col =
|
44 |
currentUser = self.get_user()
|
45 |
# Execute the query and get the first result
|
46 |
-
docs =
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
49 |
return docs
|
50 |
|
51 |
def today_user_attempts(self):
|
@@ -53,10 +54,10 @@ class FirebaseConnector():
|
|
53 |
attempted_words = []
|
54 |
if len(docs) > 0:
|
55 |
doc = docs[0]
|
56 |
-
attempted_words = doc.to_dict().get(
|
57 |
return attempted_words
|
58 |
|
59 |
def today_word(self):
|
60 |
-
words_col =
|
61 |
doc = self.db.collection(words_col).document(self.today())
|
62 |
-
return doc.get().get(
|
|
|
1 |
import os
|
|
|
|
|
|
|
2 |
from datetime import datetime
|
|
|
3 |
|
4 |
+
import firebase_admin
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from firebase_admin import credentials, firestore
|
7 |
|
|
|
8 |
|
9 |
+
class FirebaseConnector:
|
10 |
def __init__(self):
|
11 |
load_dotenv()
|
12 |
cert_path = self.get_credentials_path()
|
|
|
19 |
return db
|
20 |
|
21 |
def get_credentials_path(self):
|
22 |
+
credentials_path = os.getenv("RS_FIREBASE_CREDENTIALS_PATH")
|
23 |
return credentials_path
|
24 |
|
25 |
def get_user(self):
|
26 |
+
user = os.getenv("RS_WORDLE_USER")
|
27 |
return user
|
28 |
|
29 |
def get_state_from_fb_result(self, firebase_result):
|
30 |
+
result_number_map = {"incorrect": "0", "misplaced": "1", "correct": "2"}
|
|
|
|
|
31 |
char_result_map = map(
|
32 |
lambda char_res: result_number_map[char_res], firebase_result
|
33 |
)
|
34 |
+
return "".join(char_result_map)
|
35 |
|
36 |
def today(self):
|
37 |
+
return datetime.today().strftime("%Y%m%d")
|
38 |
|
39 |
def today_user_results(self):
|
40 |
+
daily_results_col = "dailyResults"
|
41 |
currentUser = self.get_user()
|
42 |
# Execute the query and get the first result
|
43 |
+
docs = (
|
44 |
+
self.db.collection(daily_results_col)
|
45 |
+
.where("user.email", "==", currentUser)
|
46 |
+
.where("date", "==", self.today())
|
47 |
+
.limit(1)
|
48 |
+
.get()
|
49 |
+
)
|
50 |
return docs
|
51 |
|
52 |
def today_user_attempts(self):
|
|
|
54 |
attempted_words = []
|
55 |
if len(docs) > 0:
|
56 |
doc = docs[0]
|
57 |
+
attempted_words = doc.to_dict().get("attemptedWords")
|
58 |
return attempted_words
|
59 |
|
60 |
def today_word(self):
|
61 |
+
words_col = "words"
|
62 |
doc = self.db.collection(words_col).document(self.today())
|
63 |
+
return doc.get().get("word")
|
rs_wordle_player/rs_wordle_player.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from a3c.play import get_play_model_path, suggest
|
2 |
from wordle_env.wordle import get_env
|
|
|
3 |
from .firebase_connector import FirebaseConnector
|
4 |
from .selenium_player import SeleniumPlayer
|
5 |
|
@@ -17,7 +18,7 @@ def get_attempts(fb_connector):
|
|
17 |
|
18 |
def is_game_finished(states):
|
19 |
if states:
|
20 |
-
return states[-1] ==
|
21 |
return False
|
22 |
|
23 |
|
@@ -49,5 +50,5 @@ def play():
|
|
49 |
return words, won
|
50 |
|
51 |
|
52 |
-
if __name__ ==
|
53 |
print(play())
|
|
|
1 |
from a3c.play import get_play_model_path, suggest
|
2 |
from wordle_env.wordle import get_env
|
3 |
+
|
4 |
from .firebase_connector import FirebaseConnector
|
5 |
from .selenium_player import SeleniumPlayer
|
6 |
|
|
|
18 |
|
19 |
def is_game_finished(states):
|
20 |
if states:
|
21 |
+
return states[-1] == "22222" or len(states) == 6
|
22 |
return False
|
23 |
|
24 |
|
|
|
50 |
return words, won
|
51 |
|
52 |
|
53 |
+
if __name__ == "__main__":
|
54 |
print(play())
|
rs_wordle_player/selenium_player.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
import time
|
|
|
3 |
from dotenv import load_dotenv
|
4 |
from selenium import webdriver
|
5 |
from selenium.common.exceptions import UnexpectedAlertPresentException
|
@@ -8,8 +9,7 @@ from selenium.webdriver.common.by import By
|
|
8 |
from selenium.webdriver.common.keys import Keys
|
9 |
|
10 |
|
11 |
-
class SeleniumPlayer
|
12 |
-
|
13 |
def __init__(self):
|
14 |
self.wordle_url = self.get_wordle_url()
|
15 |
self.driver = self.get_driver()
|
@@ -24,22 +24,22 @@ class SeleniumPlayer():
|
|
24 |
|
25 |
def get_wordle_url(self):
|
26 |
load_dotenv()
|
27 |
-
return os.getenv(
|
28 |
|
29 |
def get_credentials(self):
|
30 |
load_dotenv()
|
31 |
-
username = os.getenv(
|
32 |
-
password = os.getenv(
|
33 |
return username, password
|
34 |
|
35 |
def logged_in(self):
|
36 |
-
return self.driver.current_url != self.wordle_url +
|
37 |
|
38 |
def log_in(self):
|
39 |
if not self.logged_in():
|
40 |
time.sleep(2)
|
41 |
-
login_div = self.driver.find_element(By.CLASS_NAME,
|
42 |
-
login_btns = login_div.find_elements(By.TAG_NAME,
|
43 |
login_btn = login_btns[0]
|
44 |
login_btn.click()
|
45 |
time.sleep(10)
|
@@ -47,32 +47,31 @@ class SeleniumPlayer():
|
|
47 |
login_window = self.driver.window_handles[1]
|
48 |
self.driver.switch_to.window(login_window)
|
49 |
username, password = self.get_credentials()
|
50 |
-
element = self.driver.find_element(By.ID,
|
51 |
element.send_keys(username)
|
52 |
element.send_keys(Keys.ENTER)
|
53 |
time.sleep(10)
|
54 |
-
element = self.driver.find_element(By.NAME,
|
55 |
element.send_keys(password)
|
56 |
element.send_keys(Keys.ENTER)
|
57 |
self.driver.switch_to.window(wordle_window)
|
58 |
time.sleep(5)
|
59 |
onboard_div = self.driver.find_element(
|
60 |
-
By.CLASS_NAME,
|
61 |
-
'onboarding-modal-container'
|
62 |
)
|
63 |
-
onboard_btn = onboard_div.find_elements(By.TAG_NAME,
|
64 |
onboard_btn[-1].click()
|
65 |
|
66 |
def play_word(self, word):
|
67 |
try:
|
68 |
-
element = self.driver.find_element(By.TAG_NAME,
|
69 |
# simulate typing the letters in the word into the input field
|
70 |
element.send_keys(word)
|
71 |
# simulate pressing the Enter key
|
72 |
element.send_keys(Keys.ENTER)
|
73 |
time.sleep(5)
|
74 |
except UnexpectedAlertPresentException:
|
75 |
-
print(
|
76 |
|
77 |
def finish(self):
|
78 |
self.driver.quit()
|
|
|
1 |
import os
|
2 |
import time
|
3 |
+
|
4 |
from dotenv import load_dotenv
|
5 |
from selenium import webdriver
|
6 |
from selenium.common.exceptions import UnexpectedAlertPresentException
|
|
|
9 |
from selenium.webdriver.common.keys import Keys
|
10 |
|
11 |
|
12 |
+
class SeleniumPlayer:
|
|
|
13 |
def __init__(self):
|
14 |
self.wordle_url = self.get_wordle_url()
|
15 |
self.driver = self.get_driver()
|
|
|
24 |
|
25 |
def get_wordle_url(self):
|
26 |
load_dotenv()
|
27 |
+
return os.getenv("RS_WORDLE_URL")
|
28 |
|
29 |
def get_credentials(self):
|
30 |
load_dotenv()
|
31 |
+
username = os.getenv("RS_WORDLE_USER")
|
32 |
+
password = os.getenv("RS_WORDLE_PASSWORD")
|
33 |
return username, password
|
34 |
|
35 |
def logged_in(self):
|
36 |
+
return self.driver.current_url != self.wordle_url + "/login"
|
37 |
|
38 |
def log_in(self):
|
39 |
if not self.logged_in():
|
40 |
time.sleep(2)
|
41 |
+
login_div = self.driver.find_element(By.CLASS_NAME, "login-button")
|
42 |
+
login_btns = login_div.find_elements(By.TAG_NAME, "button")
|
43 |
login_btn = login_btns[0]
|
44 |
login_btn.click()
|
45 |
time.sleep(10)
|
|
|
47 |
login_window = self.driver.window_handles[1]
|
48 |
self.driver.switch_to.window(login_window)
|
49 |
username, password = self.get_credentials()
|
50 |
+
element = self.driver.find_element(By.ID, "identifierId")
|
51 |
element.send_keys(username)
|
52 |
element.send_keys(Keys.ENTER)
|
53 |
time.sleep(10)
|
54 |
+
element = self.driver.find_element(By.NAME, "password")
|
55 |
element.send_keys(password)
|
56 |
element.send_keys(Keys.ENTER)
|
57 |
self.driver.switch_to.window(wordle_window)
|
58 |
time.sleep(5)
|
59 |
onboard_div = self.driver.find_element(
|
60 |
+
By.CLASS_NAME, "onboarding-modal-container"
|
|
|
61 |
)
|
62 |
+
onboard_btn = onboard_div.find_elements(By.TAG_NAME, "button")
|
63 |
onboard_btn[-1].click()
|
64 |
|
65 |
def play_word(self, word):
|
66 |
try:
|
67 |
+
element = self.driver.find_element(By.TAG_NAME, "html")
|
68 |
# simulate typing the letters in the word into the input field
|
69 |
element.send_keys(word)
|
70 |
# simulate pressing the Enter key
|
71 |
element.send_keys(Keys.ENTER)
|
72 |
time.sleep(5)
|
73 |
except UnexpectedAlertPresentException:
|
74 |
+
print("Won game alert on screen")
|
75 |
|
76 |
def finish(self):
|
77 |
self.driver.quit()
|
wordle_env/__init__.py
CHANGED
@@ -1,13 +1,9 @@
|
|
1 |
-
from gym.envs.registration import (
|
2 |
-
registry,
|
3 |
-
register,
|
4 |
-
make,
|
5 |
-
spec,
|
6 |
-
load_env_plugins as _load_env_plugins,
|
7 |
-
)
|
8 |
import os
|
9 |
-
from . import wordle
|
10 |
|
|
|
|
|
|
|
|
|
11 |
|
12 |
register(
|
13 |
id="WordleEnv100OneAction-v0",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
|
3 |
+
from gym.envs.registration import load_env_plugins as _load_env_plugins
|
4 |
+
from gym.envs.registration import make, register, registry, spec
|
5 |
+
|
6 |
+
from . import wordle
|
7 |
|
8 |
register(
|
9 |
id="WordleEnv100OneAction-v0",
|
wordle_env/const.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
WORDLE_CHARS =
|
2 |
WORDLE_N = 5
|
3 |
REWARD = 10
|
4 |
CHAR_REWARD = 0.1
|
|
|
1 |
+
WORDLE_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
2 |
WORDLE_N = 5
|
3 |
REWARD = 10
|
4 |
CHAR_REWARD = 0.1
|
wordle_env/state.py
CHANGED
@@ -13,11 +13,11 @@ where status has codes
|
|
13 |
"""
|
14 |
import collections
|
15 |
from typing import List, Tuple
|
|
|
16 |
import numpy as np
|
17 |
|
18 |
from .const import CHAR_REWARD, WORDLE_CHARS, WORDLE_N
|
19 |
|
20 |
-
|
21 |
WordleState = np.ndarray
|
22 |
|
23 |
|
@@ -27,8 +27,8 @@ def get_nvec(max_turns: int):
|
|
27 |
|
28 |
def new(max_turns: int) -> WordleState:
|
29 |
return np.array(
|
30 |
-
[max_turns] + [0, 0, 0] * WORDLE_N * len(WORDLE_CHARS),
|
31 |
-
|
32 |
|
33 |
|
34 |
def remaining_steps(state: WordleState) -> int:
|
@@ -40,11 +40,7 @@ SOMEWHERE = 1
|
|
40 |
YES = 2
|
41 |
|
42 |
|
43 |
-
def update_from_mask(
|
44 |
-
state: WordleState,
|
45 |
-
word: str,
|
46 |
-
mask: List[int]
|
47 |
-
) -> WordleState:
|
48 |
"""
|
49 |
return a copy of state that has been updated to new state
|
50 |
|
@@ -84,14 +80,14 @@ def update_from_mask(
|
|
84 |
# Need to check this first in case there's prior maybe + yes
|
85 |
if c in prior_maybe:
|
86 |
# Then the maybe could be anywhere except here
|
87 |
-
state[offset+3*i:offset+3*i+3] = [1, 0, 0]
|
88 |
elif c in prior_yes:
|
89 |
# No maybe, definitely a yes,
|
90 |
# so it's zero everywhere except the yesses
|
91 |
for j in range(WORDLE_N):
|
92 |
# Only flip no if previously was maybe
|
93 |
-
if state[offset + 3 * j:offset + 3 * j + 3][1] == 1:
|
94 |
-
state[offset + 3 * j:offset + 3 * j + 3] = [1, 0, 0]
|
95 |
else:
|
96 |
# Just straight up no
|
97 |
_set_all_no(state, offset)
|
@@ -115,7 +111,7 @@ def get_mask(word: str, goal_word: str) -> List[int]:
|
|
115 |
mask[i] = 1
|
116 |
counts[c] -= 1
|
117 |
else:
|
118 |
-
for j in range(i+1, len(mask)):
|
119 |
if mask[j] == 2:
|
120 |
continue
|
121 |
mask[j] = 0
|
@@ -136,11 +132,7 @@ def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
|
|
136 |
return update_from_mask(state, word, mask)
|
137 |
|
138 |
|
139 |
-
def update(
|
140 |
-
state: WordleState,
|
141 |
-
word: str,
|
142 |
-
goal_word: str
|
143 |
-
) -> Tuple[WordleState, float]:
|
144 |
state = state.copy()
|
145 |
reward = 0
|
146 |
state[0] -= 1
|
@@ -158,8 +150,7 @@ def update(
|
|
158 |
cint = ord(c) - ord(WORDLE_CHARS[0])
|
159 |
offset = 1 + cint * WORDLE_N * 3
|
160 |
if goal_word[i] != c:
|
161 |
-
if
|
162 |
-
goal_word.count(c) > processed_letters.count(c)):
|
163 |
# Char at position i = no,
|
164 |
# and in other positions maybe except it had a value before,
|
165 |
# other chars stay as they are
|
@@ -184,27 +175,27 @@ def _set_if_cero(state, offset, value):
|
|
184 |
# but only if it didnt have a value before
|
185 |
for char_idx in range(0, WORDLE_N * 3, 3):
|
186 |
char_offset = offset + char_idx
|
187 |
-
if tuple(state[char_offset: char_offset + 3]) == (0, 0, 0):
|
188 |
-
state[char_offset: char_offset + 3] = value
|
189 |
|
190 |
|
191 |
def _set_yes(state, offset, char_int, char_pos):
|
192 |
# char at position char_pos = yes,
|
193 |
# all other chars at position char_pos == no
|
194 |
pos_offset = 3 * char_pos
|
195 |
-
state[offset + pos_offset:offset + pos_offset + 3] = [0, 0, 1]
|
196 |
for ocint in range(len(WORDLE_CHARS)):
|
197 |
if ocint != char_int:
|
198 |
oc_offset = 1 + ocint * WORDLE_N * 3
|
199 |
yes_index = oc_offset + pos_offset
|
200 |
-
state[yes_index:yes_index + 3] = [1, 0, 0]
|
201 |
|
202 |
|
203 |
def _set_no(state, offset, char_pos):
|
204 |
# Set offset character = no at char_pos position
|
205 |
-
state[offset + 3 * char_pos:offset + 3 * char_pos + 3] = [1, 0, 0]
|
206 |
|
207 |
|
208 |
def _set_all_no(state, offset):
|
209 |
# Set offset character = no at all positions
|
210 |
-
state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
|
|
|
13 |
"""
|
14 |
import collections
|
15 |
from typing import List, Tuple
|
16 |
+
|
17 |
import numpy as np
|
18 |
|
19 |
from .const import CHAR_REWARD, WORDLE_CHARS, WORDLE_N
|
20 |
|
|
|
21 |
WordleState = np.ndarray
|
22 |
|
23 |
|
|
|
27 |
|
28 |
def new(max_turns: int) -> WordleState:
|
29 |
return np.array(
|
30 |
+
[max_turns] + [0, 0, 0] * WORDLE_N * len(WORDLE_CHARS), dtype=np.int32
|
31 |
+
)
|
32 |
|
33 |
|
34 |
def remaining_steps(state: WordleState) -> int:
|
|
|
40 |
YES = 2
|
41 |
|
42 |
|
43 |
+
def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleState:
|
|
|
|
|
|
|
|
|
44 |
"""
|
45 |
return a copy of state that has been updated to new state
|
46 |
|
|
|
80 |
# Need to check this first in case there's prior maybe + yes
|
81 |
if c in prior_maybe:
|
82 |
# Then the maybe could be anywhere except here
|
83 |
+
state[offset + 3 * i : offset + 3 * i + 3] = [1, 0, 0]
|
84 |
elif c in prior_yes:
|
85 |
# No maybe, definitely a yes,
|
86 |
# so it's zero everywhere except the yesses
|
87 |
for j in range(WORDLE_N):
|
88 |
# Only flip no if previously was maybe
|
89 |
+
if state[offset + 3 * j : offset + 3 * j + 3][1] == 1:
|
90 |
+
state[offset + 3 * j : offset + 3 * j + 3] = [1, 0, 0]
|
91 |
else:
|
92 |
# Just straight up no
|
93 |
_set_all_no(state, offset)
|
|
|
111 |
mask[i] = 1
|
112 |
counts[c] -= 1
|
113 |
else:
|
114 |
+
for j in range(i + 1, len(mask)):
|
115 |
if mask[j] == 2:
|
116 |
continue
|
117 |
mask[j] = 0
|
|
|
132 |
return update_from_mask(state, word, mask)
|
133 |
|
134 |
|
135 |
+
def update(state: WordleState, word: str, goal_word: str) -> Tuple[WordleState, float]:
|
|
|
|
|
|
|
|
|
136 |
state = state.copy()
|
137 |
reward = 0
|
138 |
state[0] -= 1
|
|
|
150 |
cint = ord(c) - ord(WORDLE_CHARS[0])
|
151 |
offset = 1 + cint * WORDLE_N * 3
|
152 |
if goal_word[i] != c:
|
153 |
+
if c in goal_word and goal_word.count(c) > processed_letters.count(c):
|
|
|
154 |
# Char at position i = no,
|
155 |
# and in other positions maybe except it had a value before,
|
156 |
# other chars stay as they are
|
|
|
175 |
# but only if it didnt have a value before
|
176 |
for char_idx in range(0, WORDLE_N * 3, 3):
|
177 |
char_offset = offset + char_idx
|
178 |
+
if tuple(state[char_offset : char_offset + 3]) == (0, 0, 0):
|
179 |
+
state[char_offset : char_offset + 3] = value
|
180 |
|
181 |
|
182 |
def _set_yes(state, offset, char_int, char_pos):
|
183 |
# char at position char_pos = yes,
|
184 |
# all other chars at position char_pos == no
|
185 |
pos_offset = 3 * char_pos
|
186 |
+
state[offset + pos_offset : offset + pos_offset + 3] = [0, 0, 1]
|
187 |
for ocint in range(len(WORDLE_CHARS)):
|
188 |
if ocint != char_int:
|
189 |
oc_offset = 1 + ocint * WORDLE_N * 3
|
190 |
yes_index = oc_offset + pos_offset
|
191 |
+
state[yes_index : yes_index + 3] = [1, 0, 0]
|
192 |
|
193 |
|
194 |
def _set_no(state, offset, char_pos):
|
195 |
# Set offset character = no at char_pos position
|
196 |
+
state[offset + 3 * char_pos : offset + 3 * char_pos + 3] = [1, 0, 0]
|
197 |
|
198 |
|
199 |
def _set_all_no(state, offset):
|
200 |
# Set offset character = no at all positions
|
201 |
+
state[offset : offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
|
wordle_env/test_wordle.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import pytest
|
2 |
|
3 |
-
from . import wordle
|
4 |
-
from . import state
|
5 |
|
6 |
TESTWORDS = [
|
7 |
"APPAA",
|
|
|
1 |
import pytest
|
2 |
|
3 |
+
from . import state, wordle
|
|
|
4 |
|
5 |
TESTWORDS = [
|
6 |
"APPAA",
|
wordle_env/wordle.py
CHANGED
@@ -1,24 +1,22 @@
|
|
|
|
|
|
|
|
1 |
import gym
|
2 |
from gym import spaces
|
3 |
-
from typing import Optional, List
|
4 |
|
5 |
from . import state
|
6 |
-
from .const import
|
7 |
-
|
8 |
from .words import complete_vocabulary, target_vocabulary
|
9 |
|
10 |
-
import random
|
11 |
-
|
12 |
|
13 |
def _load_words(
|
14 |
-
limit: Optional[int] = None,
|
15 |
-
complete: Optional[bool] = False
|
16 |
) -> List[str]:
|
17 |
words = complete_vocabulary if complete else target_vocabulary
|
18 |
return words if not limit else words[:limit]
|
19 |
|
20 |
|
21 |
-
def get_env(env_id=
|
22 |
return gym.make(env_id)
|
23 |
|
24 |
|
@@ -42,13 +40,16 @@ class WordleEnvBase(gym.Env):
|
|
42 |
Initial state with turn 0, all chars Unvisited
|
43 |
"""
|
44 |
|
45 |
-
def __init__(
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
49 |
assert all(
|
50 |
len(w) == WORDLE_N for w in words
|
51 |
-
), f
|
52 |
self.words = words
|
53 |
self.max_turns = max_turns
|
54 |
self.allowable_words = allowable_words
|
@@ -57,8 +58,7 @@ class WordleEnvBase(gym.Env):
|
|
57 |
self.allowable_words = len(self.words)
|
58 |
|
59 |
self.action_space = spaces.Discrete(self.words_as_action_space())
|
60 |
-
self.observation_space = spaces.MultiDiscrete(
|
61 |
-
state.get_nvec(self.max_turns))
|
62 |
|
63 |
self.done = True
|
64 |
self.goal_word: int = -1
|
@@ -79,15 +79,15 @@ class WordleEnvBase(gym.Env):
|
|
79 |
word = self.words[action]
|
80 |
goal_word = self.words[self.goal_word]
|
81 |
# assert word in self.words, f'{word} not in words list'
|
82 |
-
self.state, r = self.state_updater(
|
83 |
-
|
84 |
-
|
85 |
|
86 |
reward = r
|
87 |
if action == self.goal_word:
|
88 |
self.done = True
|
89 |
# reward = REWARD
|
90 |
-
if state.remaining_steps(self.state) == self.max_turns-1:
|
91 |
reward = 0 # -10*REWARD # No reward for guessing off the bat
|
92 |
else:
|
93 |
reward = REWARD
|
@@ -100,7 +100,7 @@ class WordleEnvBase(gym.Env):
|
|
100 |
def reset(self):
|
101 |
self.state = state.new(self.max_turns)
|
102 |
self.done = False
|
103 |
-
random_word = random.choice(self.words[:self.allowable_words])
|
104 |
self.goal_word = self.words.index(random_word)
|
105 |
return self.state.copy()
|
106 |
|
@@ -121,8 +121,7 @@ class WordleEnv100OneAction(WordleEnvBase):
|
|
121 |
|
122 |
class WordleEnv100WithMask(WordleEnvBase):
|
123 |
def __init__(self):
|
124 |
-
super().__init__(words=_load_words(100),
|
125 |
-
mask_based_state_updates=True)
|
126 |
|
127 |
|
128 |
class WordleEnv100TwoAction(WordleEnvBase):
|
@@ -142,8 +141,7 @@ class WordleEnv100FullAction(WordleEnvBase):
|
|
142 |
|
143 |
class WordleEnv1000WithMask(WordleEnvBase):
|
144 |
def __init__(self):
|
145 |
-
super().__init__(words=_load_words(1000),
|
146 |
-
mask_based_state_updates=True)
|
147 |
|
148 |
|
149 |
class WordleEnv1000FullAction(WordleEnvBase):
|
@@ -158,5 +156,6 @@ class WordleEnvFull(WordleEnvBase):
|
|
158 |
|
159 |
class WordleEnvRealWithMask(WordleEnvBase):
|
160 |
def __init__(self):
|
161 |
-
super().__init__(
|
162 |
-
|
|
|
|
1 |
+
import random
|
2 |
+
from typing import List, Optional
|
3 |
+
|
4 |
import gym
|
5 |
from gym import spaces
|
|
|
6 |
|
7 |
from . import state
|
8 |
+
from .const import REWARD, WORDLE_CHARS, WORDLE_N
|
|
|
9 |
from .words import complete_vocabulary, target_vocabulary
|
10 |
|
|
|
|
|
11 |
|
12 |
def _load_words(
|
13 |
+
limit: Optional[int] = None, complete: Optional[bool] = False
|
|
|
14 |
) -> List[str]:
|
15 |
words = complete_vocabulary if complete else target_vocabulary
|
16 |
return words if not limit else words[:limit]
|
17 |
|
18 |
|
19 |
+
def get_env(env_id="WordleEnvFull-v0"):
|
20 |
return gym.make(env_id)
|
21 |
|
22 |
|
|
|
40 |
Initial state with turn 0, all chars Unvisited
|
41 |
"""
|
42 |
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
words: List[str],
|
46 |
+
max_turns: int = 6,
|
47 |
+
allowable_words: Optional[int] = None,
|
48 |
+
mask_based_state_updates: bool = False,
|
49 |
+
):
|
50 |
assert all(
|
51 |
len(w) == WORDLE_N for w in words
|
52 |
+
), f"Not all words of length {WORDLE_N}, {words}"
|
53 |
self.words = words
|
54 |
self.max_turns = max_turns
|
55 |
self.allowable_words = allowable_words
|
|
|
58 |
self.allowable_words = len(self.words)
|
59 |
|
60 |
self.action_space = spaces.Discrete(self.words_as_action_space())
|
61 |
+
self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))
|
|
|
62 |
|
63 |
self.done = True
|
64 |
self.goal_word: int = -1
|
|
|
79 |
word = self.words[action]
|
80 |
goal_word = self.words[self.goal_word]
|
81 |
# assert word in self.words, f'{word} not in words list'
|
82 |
+
self.state, r = self.state_updater(
|
83 |
+
state=self.state, word=word, goal_word=goal_word
|
84 |
+
)
|
85 |
|
86 |
reward = r
|
87 |
if action == self.goal_word:
|
88 |
self.done = True
|
89 |
# reward = REWARD
|
90 |
+
if state.remaining_steps(self.state) == self.max_turns - 1:
|
91 |
reward = 0 # -10*REWARD # No reward for guessing off the bat
|
92 |
else:
|
93 |
reward = REWARD
|
|
|
100 |
def reset(self):
|
101 |
self.state = state.new(self.max_turns)
|
102 |
self.done = False
|
103 |
+
random_word = random.choice(self.words[: self.allowable_words])
|
104 |
self.goal_word = self.words.index(random_word)
|
105 |
return self.state.copy()
|
106 |
|
|
|
121 |
|
122 |
class WordleEnv100WithMask(WordleEnvBase):
|
123 |
def __init__(self):
|
124 |
+
super().__init__(words=_load_words(100), mask_based_state_updates=True)
|
|
|
125 |
|
126 |
|
127 |
class WordleEnv100TwoAction(WordleEnvBase):
|
|
|
141 |
|
142 |
class WordleEnv1000WithMask(WordleEnvBase):
|
143 |
def __init__(self):
|
144 |
+
super().__init__(words=_load_words(1000), mask_based_state_updates=True)
|
|
|
145 |
|
146 |
|
147 |
class WordleEnv1000FullAction(WordleEnvBase):
|
|
|
156 |
|
157 |
class WordleEnvRealWithMask(WordleEnvBase):
|
158 |
def __init__(self):
|
159 |
+
super().__init__(
|
160 |
+
words=_load_words(), allowable_words=2315, mask_based_state_updates=True
|
161 |
+
)
|
wordle_env/words.py
CHANGED
@@ -7,7 +7,7 @@ _COMPLETE_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
|
|
7 |
_TARGET_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
|
8 |
94f3c0303ba6a7768b47583aff36654d/raw/\
|
9 |
d9cddf5e16140df9e14f19c2de76a0ef36fd2748/wordle-La.txt"
|
10 |
-
_DOWNLOADS_DIR =
|
11 |
_COMPLETE_VOCABULARY_FILENAME = "complete_vocabulary.txt"
|
12 |
_TARGET_VOCABULARY_FILENAME = "target_vocabulary.txt"
|
13 |
|
@@ -24,7 +24,11 @@ def _retrieve_vocabulary(url, filename, dir):
|
|
24 |
|
25 |
|
26 |
target_vocabulary = _retrieve_vocabulary(
|
27 |
-
_TARGET_VOCABULARY_URL, _TARGET_VOCABULARY_FILENAME, _DOWNLOADS_DIR
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
7 |
_TARGET_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
|
8 |
94f3c0303ba6a7768b47583aff36654d/raw/\
|
9 |
d9cddf5e16140df9e14f19c2de76a0ef36fd2748/wordle-La.txt"
|
10 |
+
_DOWNLOADS_DIR = "."
|
11 |
_COMPLETE_VOCABULARY_FILENAME = "complete_vocabulary.txt"
|
12 |
_TARGET_VOCABULARY_FILENAME = "target_vocabulary.txt"
|
13 |
|
|
|
24 |
|
25 |
|
26 |
target_vocabulary = _retrieve_vocabulary(
|
27 |
+
_TARGET_VOCABULARY_URL, _TARGET_VOCABULARY_FILENAME, _DOWNLOADS_DIR
|
28 |
+
)
|
29 |
+
complete_vocabulary = (
|
30 |
+
_retrieve_vocabulary(
|
31 |
+
_COMPLETE_VOCABULARY_URL, _COMPLETE_VOCABULARY_FILENAME, _DOWNLOADS_DIR
|
32 |
+
)
|
33 |
+
+ target_vocabulary
|
34 |
+
)
|
wordle_game.py
CHANGED
@@ -1,30 +1,28 @@
|
|
1 |
-
from rich.prompt import Prompt
|
2 |
-
from rich.console import Console
|
3 |
from random import choice
|
4 |
-
from wordle_env.words import target_vocabulary, complete_vocabulary
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
11 |
|
12 |
-
WELCOME_MESSAGE = f
|
13 |
PLAYER_INSTRUCTIONS = "You may start guessing\n"
|
14 |
GUESS_STATEMENT = "\nEnter your guess"
|
15 |
ALLOWED_GUESSES = 6
|
16 |
|
17 |
|
18 |
def correct_place(letter):
|
19 |
-
return f
|
20 |
|
21 |
|
22 |
def correct_letter(letter):
|
23 |
-
return f
|
24 |
|
25 |
|
26 |
def incorrect_letter(letter):
|
27 |
-
return f
|
28 |
|
29 |
|
30 |
def check_guess(guess, answer):
|
@@ -34,19 +32,20 @@ def check_guess(guess, answer):
|
|
34 |
for i, letter in enumerate(guess):
|
35 |
if answer[i] == guess[i]:
|
36 |
guessed[i] = correct_place(letter)
|
37 |
-
wordle_pattern.append(SQUARES[
|
38 |
processed_letters.append(letter)
|
39 |
for i, letter in enumerate(guess):
|
40 |
if answer[i] != guess[i]:
|
41 |
-
if
|
42 |
-
|
|
|
43 |
guessed[i] = correct_letter(letter)
|
44 |
-
wordle_pattern.append(SQUARES[
|
45 |
else:
|
46 |
guessed[i] = incorrect_letter(letter)
|
47 |
-
wordle_pattern.append(SQUARES[
|
48 |
processed_letters.append(letter)
|
49 |
-
return
|
50 |
|
51 |
|
52 |
def game(console, chosen_word):
|
@@ -57,12 +56,15 @@ def game(console, chosen_word):
|
|
57 |
|
58 |
while not end_of_game:
|
59 |
guess = Prompt.ask(GUESS_STATEMENT).upper()
|
60 |
-
while (
|
61 |
-
|
|
|
|
|
|
|
62 |
if guess in already_guessed:
|
63 |
console.print("[red]You've already guessed this word!!\n[/]")
|
64 |
else:
|
65 |
-
console.print(
|
66 |
guess = Prompt.ask(GUESS_STATEMENT).upper()
|
67 |
already_guessed.append(guess)
|
68 |
guessed, pattern = check_guess(guess, chosen_word)
|
@@ -74,14 +76,13 @@ def game(console, chosen_word):
|
|
74 |
end_of_game = True
|
75 |
if len(already_guessed) == ALLOWED_GUESSES and guess != chosen_word:
|
76 |
console.print(f"\n[red]WORDLE X/{ALLOWED_GUESSES}[/]")
|
77 |
-
console.print(f
|
78 |
else:
|
79 |
-
console.print(
|
80 |
-
f"\n[green]WORDLE {len(already_guessed)}/{ALLOWED_GUESSES}[/]\n")
|
81 |
console.print(*full_wordle_pattern, sep="\n")
|
82 |
|
83 |
|
84 |
-
if __name__ ==
|
85 |
console = Console()
|
86 |
chosen_word = choice(target_vocabulary)
|
87 |
console.print(WELCOME_MESSAGE)
|
|
|
|
|
|
|
1 |
from random import choice
|
|
|
2 |
|
3 |
+
from rich.console import Console
|
4 |
+
from rich.prompt import Prompt
|
5 |
+
|
6 |
+
from wordle_env.words import complete_vocabulary, target_vocabulary
|
7 |
+
|
8 |
+
SQUARES = {"correct_place": "π©", "correct_letter": "π¨", "incorrect_letter": "β¬"}
|
9 |
|
10 |
+
WELCOME_MESSAGE = f"\n[white on blue] WELCOME TO WORDLE [/]\n"
|
11 |
PLAYER_INSTRUCTIONS = "You may start guessing\n"
|
12 |
GUESS_STATEMENT = "\nEnter your guess"
|
13 |
ALLOWED_GUESSES = 6
|
14 |
|
15 |
|
16 |
def correct_place(letter):
|
17 |
+
return f"[black on green]{letter}[/]"
|
18 |
|
19 |
|
20 |
def correct_letter(letter):
|
21 |
+
return f"[black on yellow]{letter}[/]"
|
22 |
|
23 |
|
24 |
def incorrect_letter(letter):
|
25 |
+
return f"[black on white]{letter}[/]"
|
26 |
|
27 |
|
28 |
def check_guess(guess, answer):
|
|
|
32 |
for i, letter in enumerate(guess):
|
33 |
if answer[i] == guess[i]:
|
34 |
guessed[i] = correct_place(letter)
|
35 |
+
wordle_pattern.append(SQUARES["correct_place"])
|
36 |
processed_letters.append(letter)
|
37 |
for i, letter in enumerate(guess):
|
38 |
if answer[i] != guess[i]:
|
39 |
+
if letter in answer and answer.count(letter) > processed_letters.count(
|
40 |
+
letter
|
41 |
+
):
|
42 |
guessed[i] = correct_letter(letter)
|
43 |
+
wordle_pattern.append(SQUARES["correct_letter"])
|
44 |
else:
|
45 |
guessed[i] = incorrect_letter(letter)
|
46 |
+
wordle_pattern.append(SQUARES["incorrect_letter"])
|
47 |
processed_letters.append(letter)
|
48 |
+
return "".join(guessed), "".join(wordle_pattern)
|
49 |
|
50 |
|
51 |
def game(console, chosen_word):
|
|
|
56 |
|
57 |
while not end_of_game:
|
58 |
guess = Prompt.ask(GUESS_STATEMENT).upper()
|
59 |
+
while (
|
60 |
+
len(guess) != 5
|
61 |
+
or guess in already_guessed
|
62 |
+
or guess not in complete_vocabulary
|
63 |
+
):
|
64 |
if guess in already_guessed:
|
65 |
console.print("[red]You've already guessed this word!!\n[/]")
|
66 |
else:
|
67 |
+
console.print("[red]Please enter a valid 5-letter word!!\n[/]")
|
68 |
guess = Prompt.ask(GUESS_STATEMENT).upper()
|
69 |
already_guessed.append(guess)
|
70 |
guessed, pattern = check_guess(guess, chosen_word)
|
|
|
76 |
end_of_game = True
|
77 |
if len(already_guessed) == ALLOWED_GUESSES and guess != chosen_word:
|
78 |
console.print(f"\n[red]WORDLE X/{ALLOWED_GUESSES}[/]")
|
79 |
+
console.print(f"\n[green]Correct Word: {chosen_word}[/]")
|
80 |
else:
|
81 |
+
console.print(f"\n[green]WORDLE {len(already_guessed)}/{ALLOWED_GUESSES}[/]\n")
|
|
|
82 |
console.print(*full_wordle_pattern, sep="\n")
|
83 |
|
84 |
|
85 |
+
if __name__ == "__main__":
|
86 |
console = Console()
|
87 |
chosen_word = choice(target_vocabulary)
|
88 |
console.print(WELCOME_MESSAGE)
|