fhudi committed on
Commit
8bf595d
·
verified ·
1 Parent(s): 8f5315f

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ textgames-scrabble-black2-ss.png filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,10 +1,9 @@
1
  */*.DS_Store
2
  .DS_Store
3
 
4
- ssl/
5
- problemsets_*
6
-
7
  user_outputs/
 
8
 
9
  .idea/
10
 
 
1
  */*.DS_Store
2
  .DS_Store
3
 
4
+ agents/*.sh
 
 
5
  user_outputs/
6
+ model_outputs/__runs__
7
 
8
  .idea/
9
 
agents/Gemma-2-9b-it.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
+ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":1,"metadata":{"id":"Rli_enT6lBDT","executionInfo":{"status":"ok","timestamp":1737395007014,"user_tz":-540,"elapsed":5212,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"outputs":[],"source":["##%%\n","import os\n","import pickle\n","import json\n","# import random\n","# import torch\n","# import numpy as np\n","# import argparse\n","# import cohere\n","# from openai import OpenAI\n"]},{"cell_type":"code","source":["##%%\n","# import hashlib\n","from tqdm import tqdm\n","from itertools import product\n","# from collections import Counter\n","\n","# from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM\n","from transformers import AutoTokenizer, AutoModelForCausalLM\n","from textgames import GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name\n"],"metadata":{"id":"dp1F32B8oSfD","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395010583,"user_tz":-540,"elapsed":3547,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"e9adeb5f-70eb-4ca9-dcbb-428e4b28ab41"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stderr","text":["/home/is/frederikus-h/miniconda3/envs/textgame/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n"," from .autonotebook import tqdm as notebook_tqdm\n"]}]},{"cell_type":"code","source":["os.environ.setdefault(\"TEXTGAMES_OUTPUT_DIR\", \"user_outputs\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"2wEu1V1wvxn0","executionInfo":{"status":"ok","timestamp":1737395010664,"user_tz":-540,"elapsed":67,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"cdcad20f-e357-4009-9f4f-0d4495ebd894"},"execution_count":3,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'user_outputs'"]},"metadata":{},"execution_count":3}]},{"cell_type":"code","source":["##%%\n","gen_model_checkpoint = \"google/gemma-2-9b-it\"\n","quantize = True"],"metadata":{"id":"jZF8bkUcojTX","executionInfo":{"status":"ok","timestamp":1737395010678,"user_tz":-540,"elapsed":13,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":4,"outputs":[]},{"cell_type":"code","source":["kwargs = {\n"," \"device_map\": \"auto\",\n","} if quantize else {}"],"metadata":{"id":"VAF5sR9arYzS","executionInfo":{"status":"ok","timestamp":1737395010683,"user_tz":-540,"elapsed":2,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":5,"outputs":[]},{"cell_type":"code","source":["##%%\n","gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **kwargs)\n","tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, **kwargs)"],"metadata":{"id":"tzqldl8ooRVL","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395038547,"user_tz":-540,"elapsed":27859,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"902b638c-e6ce-4f8a-bba2-e9f7241c9a27"},"execution_count":6,"outputs":[{"output_type":"stream","name":"stderr","text":["Loading checkpoint shards: 
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:24<00:00, 6.19s/it]\n"]}]},{"cell_type":"code","source":["gen_model.device"],"metadata":{"id":"FeBUXdkWsWrL","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395038552,"user_tz":-540,"elapsed":3,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"6437d1b7-02f8-47f5-d519-e979cefde795"},"execution_count":7,"outputs":[{"output_type":"execute_result","data":{"text/plain":["device(type='cuda', index=0)"]},"metadata":{},"execution_count":7}]},{"cell_type":"code","source":["def get_gemma_response(text):\n"," # global gen_model, tokenizer\n"," messages = [\n"," {\"role\": \"user\", \"content\": text},\n"," ]\n","\n"," input_ids = tokenizer.apply_chat_template(\n"," messages,\n"," add_generation_prompt=True,\n"," return_tensors=\"pt\"\n"," ).to(gen_model.device)\n","\n"," terminators = [\n"," tokenizer.eos_token_id,\n"," tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n"," ]\n","\n"," outputs = gen_model.generate(\n"," input_ids,\n"," max_new_tokens=100,\n"," eos_token_id=terminators,\n"," do_sample=True,\n"," temperature=.001,\n"," top_p=1,\n"," )\n","\n"," response = outputs[0][input_ids.shape[-1]:]\n"," return tokenizer.decode(response, skip_special_tokens=True)"],"metadata":{"id":"R5D4K-P2sPaj","executionInfo":{"status":"ok","timestamp":1737395038554,"user_tz":-540,"elapsed":1,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":8,"outputs":[]},{"cell_type":"markdown","source":["---\n","Example Call"],"metadata":{"id":"s5FEwOOvxf4h"}},{"cell_type":"code","source":["# @title\n","text = \\\n","\"\"\"\n","Given a set of rules to calculate point, sort the set of words in decreasing order.\n","When there 2 or more words with same point, sort lexicographically.\n","\n","Rules:\n","- every pair of 
consecutive consonant gets 5 points\n","- every pair of consecutive vowel gets 3 points\n","- add 1 point if there exists exactly 1 'g' in the word\n","- word less than 5 characters gets 10 points\n","- word starts with gen gets 100 points\n","- word ends with ta gets -1000 point\n","\n","Words:\n","- genta\n","- winata\n","- hudi\n","- alham\n","- aji\n","- ruochen\n","\n","Print only the answer.\n","\"\"\"\n","\n","print(text)"],"metadata":{"id":"T_tk4hTGsxsR","colab":{"base_uri":"https://localhost:8080/"},"cellView":"form","executionInfo":{"status":"ok","timestamp":1737392776367,"user_tz":-540,"elapsed":27,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"d5ea884f-d0fa-4134-ecd9-690eab51c976"},"execution_count":14,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","Given a set of rules to calculate point, sort the set of words in decreasing order.\n","When there 2 or more words with same point, sort lexicographically.\n","\n","Rules:\n","- every pair of consecutive consonant gets 5 points\n","- every pair of consecutive vowel gets 3 points\n","- add 1 point if there exists exactly 1 'g' in the word\n","- word less than 5 characters gets 10 points\n","- word starts with gen gets 100 points\n","- word ends with ta gets -1000 point\n","\n","Words:\n","- genta\n","- winata\n","- hudi\n","- alham\n","- aji\n","- ruochen\n","\n","Print only the answer.\n","\n"]}]},{"cell_type":"code","source":["# Gold Answer:\n","# - aji 10\n","# - hudi 10\n","# - ruochen 5 3\n","# - alham 5\n","# - genta 5 1 100 -1000\n","# - winata -1000"],"metadata":{"id":"G-5yS4S-rdsN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(get_gemma_response(text))"],"metadata":{"id":"05OI36v6vGoY","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737392724119,"user_tz":-540,"elapsed":6741,"user":{"displayName":"Frederikus 
Hudi","userId":"06160664103998835801"}},"outputId":"fe5d6ed2-d063-4f1c-b2e1-b3af8dbc456e"},"execution_count":9,"outputs":[{"output_type":"stream","name":"stdout","text":["genta\n","winata\n","ruochen\n","hudi\n","alham\n","aji \n","\n"]}]},{"cell_type":"markdown","source":["---\n","Automate run all sessions"],"metadata":{"id":"cxJ4WqHpxi75"}},{"cell_type":"code","source":["for game_name, difficulty_level in product([GAME_NAMES[4], *GAME_NAMES[:4], *GAME_NAMES[5:]], LEVEL_IDS[:3]):\n"," game_cls = _game_class_from_name(game_name)\n"," with open(f\"problemsets/{game_filename(game_name)}_{difficulty_level}.json\", \"r\", encoding=\"utf8\") as f:\n"," sid_prompt_dict = json.load(f)\n","\n"," correct_cnt = 0\n"," for sid, prompt in tqdm(list(sid_prompt_dict.items()), desc=f\"{game_filename(game_name)}_-_{difficulty_level}\"):\n"," cur_game = game_cls()\n"," cur_game.load_game(prompt)\n"," response = get_gemma_response(cur_game.get_prompt()).strip()\n"," solved, val_msg = cur_game.validate(response)\n"," with open(f\"model_outputs/results_gemma_2_9B_it.pkl\", \"ab\") as o:\n"," pickle.dump((f\"{game_filename(game_name)}_{difficulty_level}\", sid, response, (solved, val_msg)), o)\n"," if solved:\n"," correct_cnt += 1\n","\n"," print(f\"{game_name}_-_{difficulty_level}\")\n"," print(f\" Acc.: {correct_cnt / len(sid_prompt_dict):.2%}\")"],"metadata":{"id":"hCTXYpXa1UQ6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"GC-zkVI52IJX"},"execution_count":null,"outputs":[]}]}
agents/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Define the __all__ variable
2
+ __all__ = ["run_with_agent"]
3
+
4
+ # Import the submodules
5
+ from .runner import run_with_agent
agents/_reference.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import argparse
6
+ import json
7
+ import cohere
8
+ from openai import OpenAI
9
+
10
+ from tqdm import tqdm
11
+
12
+ from collections import Counter
13
+
14
+ from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
15
+ import hashlib
16
+
17
+
18
# API credentials are read from the environment so that real secrets never
# have to be pasted into (and accidentally committed with) this file.
# Each one falls back to the previous hard-coded empty string when unset.
OPENAI_TOKEN = os.getenv("OPENAI_API_KEY", "")
COHERE_TOKEN = os.getenv("CO_API_KEY", "")
HF_TOKEN = os.getenv("HF_TOKEN", "")
21
+
22
+
23
def argmax(array):
    """Return the index of the maximum with deterministic tie breaking.

    Args:
        array: 1-D array-like of numbers (``numpy`` array or plain sequence).

    Returns:
        Index of a maximal element. When several elements tie, the winner is
        picked pseudo-randomly from a SHA-256 hash of the array's bytes, so
        the same input always yields the same index.
    """
    # Coerce first: comparing a plain list with ``==`` yields a single bool,
    # which left ``max_indices`` empty and caused a modulo-by-zero below.
    values = np.asarray(array)
    max_indices = np.arange(len(values))[values == np.max(values)]
    digest = hashlib.sha256(values.tobytes()).hexdigest()
    idx = int(digest, 16) % len(max_indices)
    return max_indices[idx]
28
+
29
+
30
def logsumexp(x):
    """Compute ``log(sum(exp(x)))`` in a numerically stable way.

    Args:
        x: Non-empty array-like of log-values.

    Returns:
        The log of the summed exponentials. The maximum is subtracted before
        exponentiating so large inputs do not overflow.
    """
    x = np.asarray(x)  # accept plain sequences, not only ndarrays
    c = x.max()
    # With an infinite (or NaN) maximum, ``x - c`` would evaluate inf - inf
    # and poison the result with NaN; the maximum itself is the answer.
    if not np.isfinite(c):
        return c
    return c + np.log(np.sum(np.exp(x - c)))
33
+
34
+
35
def normalize(x):
    """Turn log-scores into values that sum to one (a softmax).

    Args:
        x: Array-like of log-scores; it is copied into a new ``numpy`` array.

    Returns:
        ``exp(x - logsumexp(x))`` as a ``numpy`` array.
    """
    log_scores = np.array(x)
    log_total = logsumexp(log_scores)
    return np.exp(log_scores - log_total)
38
+
39
+
40
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device rather than only the current one, matching the
    # ``set_all_seed`` helpers used by the agent scripts.
    torch.cuda.manual_seed_all(seed)
45
+
46
+
47
def get_commandr_chat_response(gen_model, gen_model_checkpoint, text, seed):
    """Send ``text`` to Cohere's ``command-r`` chat model and return the reply.

    Args:
        gen_model: Client object exposing ``chat(**kwargs)``.
        gen_model_checkpoint: Unused; the model name is fixed to ``command-r``.
        text: User message.
        seed: Sampling seed forwarded to the API.

    Returns:
        The ``text`` attribute of the chat response.
    """
    request = {
        "model": "command-r",
        "message": text,
        "temperature": 0,
        "max_tokens": 64,
        "seed": seed,
        "p": 1,
    }
    reply = gen_model.chat(**request)
    return reply.text
57
+
58
+
59
def get_mt0_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a short sampled completion for ``text``.

    The prompt is encoded as-is (no chat template) and the whole first output
    sequence is decoded, without slicing off any prompt tokens.
    ``gen_model_checkpoint`` and ``seed`` are accepted but not used.

    Returns:
        The decoded output with special tokens removed.
    """
    prompt_ids = tokenizer.encode(text, return_tensors="pt").to(gen_model.device)
    sampling = {
        "max_new_tokens": 10,
        "do_sample": True,
        "temperature": 0.2,
        "top_p": 1,
    }
    sequences = gen_model.generate(prompt_ids, **sampling)
    return tokenizer.decode(sequences[0], skip_special_tokens=True)
72
+
73
+
74
def get_gemma_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Run one single-turn chat completion and return only the new text.

    ``text`` is wrapped as a lone user message, rendered through the
    tokenizer's chat template and sampled for at most 10 new tokens.
    ``gen_model_checkpoint`` and ``seed`` are accepted but not used.

    Returns:
        The decoded completion (prompt tokens stripped, special tokens
        removed).
    """
    chat = [{"role": "user", "content": text}]
    prompt_ids = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt"
    ).to(gen_model.device)

    # NOTE(review): "<|eot_id|>" looks like a Llama-3 end-of-turn marker;
    # confirm it exists in this tokenizer's vocabulary.
    stop_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    sequences = gen_model.generate(
        prompt_ids,
        max_new_tokens=10,
        eos_token_id=stop_ids,
        do_sample=True,
        temperature=0.2,
        top_p=1,
    )

    prompt_len = prompt_ids.shape[-1]
    return tokenizer.decode(sequences[0][prompt_len:], skip_special_tokens=True)
101
+
102
+
103
def get_mistral_instruct_chat_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Answer a single user message with a short sampled completion.

    Args:
        gen_model: Model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer providing the chat template and decoding.
        gen_model_checkpoint: Unused.
        text: Content of the single user turn.
        seed: Unused.

    Returns:
        Decoded text of the newly generated tokens only.
    """
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": text}],
        add_generation_prompt=True,
        return_tensors="pt",
    )
    input_ids = input_ids.to(gen_model.device)

    generation_kwargs = {
        "max_new_tokens": 10,
        # NOTE(review): "<|eot_id|>" looks like a Llama-3 marker; confirm it
        # resolves to a real token for this tokenizer.
        "eos_token_id": [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ],
        "do_sample": True,
        "temperature": 0.2,
        "top_p": 1,
    }
    sequences = gen_model.generate(input_ids, **generation_kwargs)

    # Everything up to the prompt length is the echoed input.
    completion = sequences[0][input_ids.shape[-1]:]
    return tokenizer.decode(completion, skip_special_tokens=True)
130
+
131
+
132
def get_llama3_instruct_chat_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a short sampled reply to one user message.

    Generation stops at the tokenizer's EOS token or at ``<|eot_id|>``, or
    after 10 new tokens. ``gen_model_checkpoint`` and ``seed`` are unused.

    Returns:
        The decoded reply without the prompt and without special tokens.
    """
    conversation = [{"role": "user", "content": text}]
    encoded_prompt = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(gen_model.device)

    eos_id = tokenizer.eos_token_id
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

    generated = gen_model.generate(
        encoded_prompt,
        max_new_tokens=10,
        eos_token_id=[eos_id, eot_id],
        do_sample=True,
        temperature=0.2,
        top_p=1,
    )

    n_prompt_tokens = encoded_prompt.shape[-1]
    reply_ids = generated[0][n_prompt_tokens:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
159
+
160
+
161
def get_openai_chat_response(gen_model, gen_model_checkpoint, text, seed):
    """Request a chat completion for one user message.

    Args:
        gen_model: Client exposing ``chat.completions.create``.
        gen_model_checkpoint: Model name passed to the API.
        text: Content of the user message.
        seed: Seed forwarded to the API.

    Returns:
        The message content of the first returned choice.
    """
    conversation = [{"role": "user", "content": text}]
    completion = gen_model.chat.completions.create(
        model=gen_model_checkpoint,
        messages=conversation,
        temperature=0,
        max_tokens=64,
        top_p=1,
        seed=seed,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
177
+
178
+
179
def _load_hf_pair(model_cls, gen_model_checkpoint, load_in_8bit):
    """Load a Hugging Face model of ``model_cls`` together with its tokenizer."""
    extra = {"device_map": "auto", "load_in_8bit": True} if load_in_8bit else {}
    gen_model = model_cls.from_pretrained(gen_model_checkpoint, token=HF_TOKEN, **extra)
    tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, token=HF_TOKEN, **extra)
    return gen_model, tokenizer


def load_model(gen_model_checkpoint, load_in_8bit=False):
    """Instantiate the model (and tokenizer, if any) for a checkpoint name.

    The checkpoint is matched by substring against the known model families,
    in the order listed below.

    Args:
        gen_model_checkpoint: Hugging Face repo id or API model name.
        load_in_8bit: When true, local models are loaded with
            ``device_map="auto"`` and ``load_in_8bit=True``.

    Returns:
        ``(gen_model, tokenizer)``. ``tokenizer`` is ``None`` for the OpenAI
        and Cohere clients; both are ``None`` for an unrecognised checkpoint.
    """
    def _matches(*needles):
        return any(needle in gen_model_checkpoint for needle in needles)

    if _matches("mistralai/Mistral-7B-Instruct-v0.3",
                "meta-llama/Meta-Llama-3-8B-Instruct",
                "google/gemma-1.1-7b-it"):
        return _load_hf_pair(AutoModelForCausalLM, gen_model_checkpoint, load_in_8bit)
    if _matches("CohereForAI/aya-101", "bigscience/mt0"):
        return _load_hf_pair(AutoModelForSeq2SeqLM, gen_model_checkpoint, load_in_8bit)
    # The original tested "aya-23-8B" against an undefined ``args`` object,
    # raising NameError for every checkpoint that reached this point
    # (including all OpenAI and Cohere model names below).
    if _matches("facebook/xglm", "bigscience/bloomz", "aya-23-8B"):
        return _load_hf_pair(AutoModelForCausalLM, gen_model_checkpoint, load_in_8bit)
    if _matches("gpt-3.5-turbo", "gpt-4"):
        return OpenAI(api_key=OPENAI_TOKEN), None
    if _matches("command-r"):
        return cohere.Client(COHERE_TOKEN), None
    return None, None
216
+
agents/chatgpt.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import os
3
+ import re
4
+
5
+ #%%
6
+ import os
7
+ import json
8
+ import torch
9
+ import numpy as np
10
+ from pathlib import Path
11
+ from transformers import set_seed
12
+ from textgames import GAME_NAMES, LEVEL_IDS, game_filename
13
+ from agents import run_with_agent
14
+
15
+ #%%
16
def set_all_seed(seed=42):
    """Apply ``seed`` to transformers, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        set_seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
21
+
22
+
23
+ #%%
24
+ def _getenv_as_int(attr, default=None):
25
+ ret = os.getenv(attr, default)
26
+ return None if ret is None else int(ret)
27
+
28
+
29
# Run configuration, read once at import time from TG_* environment variables.
# The *_ST / *_ED pairs are used as slice/range bounds further down; ``None``
# leaves that side open.
GAME_ST = _getenv_as_int("TG_GAME_ST", None)
GAME_ED = _getenv_as_int("TG_GAME_ED", None)
LVL_ST = _getenv_as_int("TG_LEVEL_ST", None)
LVL_ED = _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST = _getenv_as_int("TG_SID_ST", None)
SID_ED = _getenv_as_int("TG_SID_ED", None)
N_TURNS = _getenv_as_int("TG_N_TURNS", 1)
# "0"/"1" flag: prepend a worked example to the prompt.
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
# Model tag used inside the batch input/output file names.
GPT_MODEL = os.getenv("TG_GPT_MODEL", "")
35
+ # MAX_NEW_TOKENS = _getenv_as_int("TG_MAX_NEW_TOKENS", 12000)
36
+
37
+
38
+ #%%
39
def preload_responses():
    """Load previously produced ChatGPT batch outputs from disk.

    For each turn ``1..N_TURNS`` the JSONL file named by the environment
    variable ``TG_GPT_OUTPUT_FILE_TURN_<turn>`` (or the default path under
    ``model_outputs/__runs__``) is parsed. Loading stops at the first turn
    whose file is missing; a notice is printed unless that is the last turn.

    Each record's ``custom_id`` is split on ``-``; its last two fields are
    taken as the session id and the game/level key.

    Returns:
        dict mapping ``(game_level_key, turn)`` to ``{session_id: content}``.
    """
    responses_all = dict()
    for _turn in range(1, N_TURNS + 1):
        fp = os.getenv(
            f"TG_GPT_OUTPUT_FILE_TURN_{_turn}",
            (f"model_outputs/__runs__/chatgpt_4o_mini_results/raw/batch_output_chatgpt-{GPT_MODEL}_turn{_turn}"
             f"{'.1s' if ONE_SHOT else '.zs'}.jsonl")
        )
        if not Path(fp).exists():
            if _turn < N_TURNS:
                print(f" batch_output turn {_turn} is not available. path: \"{fp}\"")
            break
        with open(fp, "r", encoding="utf8") as i:
            # Stream the file; blank lines (e.g. a trailing newline) are
            # skipped instead of crashing json.loads.
            for line in i:
                if not line.strip():
                    continue
                d = json.loads(line)
                sid, g = d['custom_id'].rsplit('-', 2)[-2:]
                msg = d['response']['body']['choices'][0]['message']
                responses_all.setdefault((g, _turn), dict())[sid] = msg['content']
    return responses_all
62
+ RESPONSES_ALL = preload_responses()
63
+ print(f"len(RESPONSES_ALL) = {len(RESPONSES_ALL)}")
64
+
65
+
66
+ #%%
67
# Answer-extraction patterns, tried in order; the first one that matches wins.
_GPT_ANSWER_PATTERNS = (
    re.compile(r'^```\n?([^`]*)\n?```'),
    re.compile(r'((.|\n)*)\n\nExplanation:\n'),
)


def gpt_postproc(response_txt_batch, *args, **kwargs):
    """Extract the answer portion from a raw ChatGPT response.

    Args:
        response_txt_batch: A single response string, or ``None``.
        *args, **kwargs: Ignored.

    Returns:
        ``None`` for a ``None`` input. Otherwise the stripped first capture
        group of the first matching pattern (a leading code fence, or the
        text before an ``Explanation:`` section); if nothing matches or the
        capture is empty, the response is returned unchanged.
    """
    response_txt = response_txt_batch
    if response_txt is None:
        return None
    for pattern in _GPT_ANSWER_PATTERNS:
        found = pattern.search(response_txt)
        if found:
            extracted = found.group(1).strip()
            return extracted if extracted else response_txt
    return response_txt
87
+
88
+
89
+ #%%
90
def get_gpt_response(texts, game_name, difficulty_level, turn, *args, **kwargs):
    """Look up a preloaded ChatGPT response, or queue a batch request for it.

    Args:
        texts: Conversation so far; even positions are user turns, odd
            positions assistant turns.
        game_name: Game name, converted with ``game_filename``.
        difficulty_level: Level id appended to the game file name.
        turn: Turn number used as part of the lookup key.
        **kwargs: Must contain ``sid``, the session id.

    Returns:
        The stored response for this session when responses for
        ``(game/level, turn)`` were preloaded. Otherwise ``None``; in that
        case, if ``TG_GPT_NEXTTURN_OUTPUT_FILE`` is set, one batch request
        line is appended to that file.

    Raises:
        KeyError: If ``sid`` is missing from ``kwargs`` or from the preloaded
            responses of this game/level/turn.
    """
    sid = kwargs['sid']  # sid must be fed as params
    roles = ("user", "assistant")
    messages = [
        {"role": roles[i % 2], "content": text}
        for i, text in enumerate(texts)
    ]
    game_level = f"{game_filename(game_name)}_{difficulty_level}"

    preloaded = RESPONSES_ALL.get((game_level, turn), {})
    if preloaded:
        return preloaded[sid]

    fp_next = os.getenv("TG_GPT_NEXTTURN_OUTPUT_FILE", None)
    if fp_next:
        request = {
            'custom_id': f"{sid}-{game_level}",
            "method": "POST", "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o-mini-2024-07-18",
                "max_completion_tokens": 200,
                'messages': messages,
                "seed": 42,
                "temperature": 0,
            }
        }
        with open(fp_next, "a", encoding="utf8") as o:
            o.write(json.dumps(request))
            o.write("\n")

    return None
121
+
122
+
123
+ #%%
124
if __name__ == "__main__":
    # The output path encodes the model tag, the prompting mode and, when
    # set, the starting game / level indices of this run.
    shot_tag = '.1s' if ONE_SHOT else '.zs'
    game_tag = '' if GAME_ST is None else f'.{GAME_ST}'
    level_tag = '' if LVL_ST is None else f'.{LVL_ST}'
    fp_out = (f"model_outputs/__runs__/chatgpt_4o_mini_results/process/results_chatgpt-{GPT_MODEL}"
              f"{shot_tag}{game_tag}{level_tag}.jsonl")

    set_all_seed()

    # Restrict to a session range only when at least one bound is given.
    if SID_ST or SID_ED:
        sid_indices = [f"session_{r:04}" for r in range(SID_ST or 0, SID_ED or 1000)]
    else:
        sid_indices = None

    run_with_agent(
        fp_out,
        get_gpt_response,
        gpt_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=sid_indices,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
        assistant_uses_raw_response=False,
    )
agents/dsr1_distill.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import os
3
+ import re
4
+
5
+ #%%
6
+ import torch
7
+ import numpy as np
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
9
+ from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
10
+ from agents import run_with_agent
11
+
12
+ #%%
13
def set_all_seed(seed=42):
    """Seed every random number generator used by this script.

    Args:
        seed: Integer seed (default 42) passed to transformers' ``set_seed``,
            NumPy, and PyTorch's CPU and CUDA generators.
    """
    set_seed(seed)
    np.random.seed(seed)
    # PyTorch: CPU generator first, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
18
+
19
+
20
+ #%%
21
+ def _getenv_as_int(attr, default=None):
22
+ ret = os.getenv(attr, default)
23
+ return None if ret is None else int(ret)
24
+
25
+
26
# Run configuration, read once at import time from TG_* environment variables.
# The *_ST / *_ED pairs are used as slice/range bounds further down; ``None``
# leaves that side open.
GAME_ST = _getenv_as_int("TG_GAME_ST", None)
GAME_ED = _getenv_as_int("TG_GAME_ED", None)
LVL_ST = _getenv_as_int("TG_LEVEL_ST", None)
LVL_ED = _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST = _getenv_as_int("TG_SID_ST", None)
SID_ED = _getenv_as_int("TG_SID_ED", None)
N_TURNS = _getenv_as_int("TG_N_TURNS", 1)
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
MAX_NEW_TOKENS = _getenv_as_int("TG_MAX_NEW_TOKENS", 12000)
DSR1_SIZE = os.getenv("TG_DSR1_SIZE", "14")    # {1.5, 7, 8, 14, 32, 70}
# Size tag -> "<base family>-<size>" fragment of the distilled model's repo name.
DSR1_NAME = {
    size: f"{family}-{size}"
    for size, family in (
        ("1.5", "Qwen"),
        ("7", "Qwen"),
        ("8", "Llama"),
        ("14", "Qwen"),
        ("32", "Qwen"),
        ("70", "Llama"),
    )
}
41
+
42
+
43
+ #%%
44
# Answer-extraction patterns, tried in order; the first one that matches wins.
_DSR1_ANSWER_PATTERNS = (
    re.compile(r'\\boxed\{([\s\S]*)}'),
    re.compile(r'</think>\n([\s\S]*)$'),
    re.compile(r'^```\n?([^`]*)\n?```'),
)


def dsr1_postproc(response_txt_batch, *args, **kwargs):
    """Extract the final answer from a DeepSeek-R1 style response.

    Args:
        response_txt_batch: A single raw response string (may be ``None``).
        *args, **kwargs: Ignored.

    Returns:
        The stripped capture of the first matching pattern: the contents of
        ``\\boxed{...}`` (greedy, up to the last ``}``), the text following
        ``</think>``, or a leading code fence. If nothing matches, the first
        256 characters of the response, stripped. ``""`` for an empty or
        ``None`` response.
    """
    response_txt = response_txt_batch
    # Guard first: the original ran the regexes before its own falsy check,
    # so a ``None`` response raised TypeError instead of yielding "".
    if not response_txt:
        return ""
    for pat in _DSR1_ANSWER_PATTERNS:
        matches = pat.search(response_txt)
        if matches:
            return matches.group(1).strip()
    return response_txt[:256].strip()
63
+
64
+
65
+ #%%
66
def get_dsr1_response(texts_batch, *args, **kwargs):
    """Generate the model's next reply for one conversation.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        texts_batch: List of turn texts for a single conversation; even
            positions are user turns, odd positions assistant turns.
        *args, **kwargs: Ignored.

    Returns:
        The stripped, decoded text of the newly generated tokens.
    """
    # The rest of the function is written for a batch; wrap the single
    # conversation as a batch of one.
    texts_batch = [texts_batch]
    for texts in texts_batch:
        # NOTE: this rewrites the caller's list in place (kept from the
        # original). Once wrapped, the entry no longer starts with
        # 'Correct guess.', so repeated calls do not wrap it twice.
        if len(texts) > 1 and texts[1].startswith('Correct guess.'):
            texts[1] = f"\\boxed{{{texts[1]}}}"
    messages = [
        [
            {"role": "user",
             "content": f"{text}\nPlease reason step by step, and put your final answer within \\boxed{{}} as plain text."}
            if i % 2 == 0 else
            # Fixed: the original passed ``{text}`` (a one-element set)
            # instead of the string itself.
            {"role": "assistant", "content": text}
            for i, text in enumerate(texts)
        ]
        for texts in texts_batch
    ]
    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # The template output is already fully rendered, so no extra special
    # tokens are added when tokenizing it.
    model_inputs = tokenizer(text_inputs, return_tensors="pt", add_special_tokens=False).to(model.device)
    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the echoed prompt tokens from each returned sequence.
    generated_ids = [
        _output_ids[len(input_ids):] for input_ids, _output_ids in zip(model_inputs.input_ids, output_ids)
    ]
    response = [r.strip() for r in tokenizer.batch_decode(generated_ids, skip_special_tokens=True)]
    return response[0]
99
+
100
+
101
+ #%%
102
+ # response = get_dsr1_response(texts)
103
+ # print(dsr1_postproc(response))
104
+
105
+
106
+ #%%
107
if __name__ == "__main__":
    # The output path encodes the model size, the prompting mode and, when
    # set, the starting game / level indices of this run.
    shot_tag = '.1s' if ONE_SHOT else '.zs'
    game_tag = '' if GAME_ST is None else f'.{GAME_ST}'
    level_tag = '' if LVL_ST is None else f'.{LVL_ST}'
    fp_out = (f"model_outputs/__runs__/results_deepseek-r1-distill-{DSR1_SIZE}b"
              f"{shot_tag}{game_tag}{level_tag}.jsonl")

    set_all_seed()
    model_name = f"deepseek-ai/DeepSeek-R1-Distill-{DSR1_NAME[DSR1_SIZE]}B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype="auto",
    )
    # Clear the sampling knobs on the model's generation config; generation
    # itself is called with do_sample=False.
    for sampling_attr in ("temperature", "top_k", "top_p"):
        setattr(model.generation_config, sampling_attr, None)

    # Restrict to a session range only when at least one bound is given.
    if SID_ST or SID_ED:
        sid_indices = [f"session_{r:04}" for r in range(SID_ST or 0, SID_ED or 1000)]
    else:
        sid_indices = None

    run_with_agent(
        fp_out,
        get_dsr1_response,
        dsr1_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=sid_indices,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
        assistant_uses_raw_response=False,
    )
agents/gemma_2_9b_it.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import os
3
+ import re
4
+
5
+ #%%
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
8
+ from agents import run_with_agent
9
+
10
+
11
+ #%%
12
+ def _getenv_as_int(attr, default=None):
13
+ ret = os.getenv(attr, default)
14
+ return None if ret is None else int(ret)
15
+
16
+
17
# Run configuration, read once at import time from TG_* environment variables.
# The *_ST / *_ED pairs are used as slice/range bounds further down; ``None``
# leaves that side open.
GAME_ST = _getenv_as_int("TG_GAME_ST", None)
GAME_ED = _getenv_as_int("TG_GAME_ED", None)
LVL_ST = _getenv_as_int("TG_LEVEL_ST", None)
LVL_ED = _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST = _getenv_as_int("TG_SID_ST", None)
SID_ED = _getenv_as_int("TG_SID_ED", None)
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
# NOTE(review): original note lists {3, 9, 27}; confirm the smallest size
# against the published gemma-2 checkpoints.
GEMMA_SIZE = int(os.getenv("TG_GEMMA_SIZE", "9"))    # {3, 9, 27}
23
+
24
+
25
+ #%%
26
def gemma_postproc(response_txt, game_name, difficulty_level, *args, **kwargs):
    """Extract the answer from a raw Gemma response.

    Args:
        response_txt: Raw model output (may be ``None`` or empty).
        game_name: Unused.
        difficulty_level: Unused.

    Returns:
        In order of preference: the contents of a code fence at the very
        start of the response, stripped and with all spaces removed; the
        contents of the first ``**bold**`` span (surrounding double quotes
        dropped), stripped; otherwise the response unchanged. ``""`` for an
        empty or ``None`` response.
    """
    # Guard first: the original only handled None in its final ``or ""``,
    # after the regex search had already raised TypeError.
    if not response_txt:
        return ""

    fenced = re.search(r'^```\n?([^`]*)\n?```', response_txt)
    if fenced:
        return fenced.group(1).strip().replace(" ", "")

    bold = re.search(r'\*\*\"?([^\"*]*)\"?\*\*', response_txt)
    if bold:
        return bold.group(1).strip()

    return response_txt
40
+
41
+
42
+ #%%
43
def get_gemma_response(texts, game_name, difficulty_level, turn, *args, **kwargs):
    """Generate Gemma's next reply for a conversation (greedy decoding).

    Uses the module-level ``gen_model`` and ``tokenizer``.

    Args:
        texts: Turn texts; even positions get role ``user``, odd positions
            role ``model``.
        game_name, difficulty_level, turn: Unused.

    Returns:
        The stripped, decoded text of at most 100 newly generated tokens.
    """
    messages = []
    for i, text in enumerate(texts):
        role = "user" if i % 2 == 0 else "model"
        messages.append({"role": role, "content": text})

    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    input_ids = input_ids.to(gen_model.device)

    # NOTE(review): "<|eot_id|>" looks like a Llama-3 end-of-turn marker;
    # confirm it exists in this tokenizer's vocabulary.
    stop_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    # Clear the configured temperature; decoding below is greedy.
    gen_model.generation_config.temperature = None
    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=100,
        eos_token_id=stop_ids,
        do_sample=False,
    )

    prompt_len = input_ids.shape[-1]
    answer_ids = outputs[0][prompt_len:]
    return tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
73
+
74
+
75
+ #%%
76
if __name__ == "__main__":
    # The output path encodes the model size, the prompting mode and, when
    # set, the starting game / level indices of this run.
    shot_tag = '.1s' if ONE_SHOT else '.zs'
    game_tag = '' if GAME_ST is None else f'.{GAME_ST}'
    level_tag = '' if LVL_ST is None else f'.{LVL_ST}'
    fp_out = (f"model_outputs/results_gemma-2-{GEMMA_SIZE}b-it"
              f"{shot_tag}{game_tag}{level_tag}.jsonl")
    gen_model_checkpoint = f"google/gemma-2-{GEMMA_SIZE}b-it"

    # Despite its name, this flag only switches on device_map="auto"; no
    # quantization options are passed here.
    quantize = True
    _kwargs = {"device_map": "auto"} if quantize else {}

    gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, **_kwargs)
    print(f" > model.dtype: {gen_model.dtype}")

    # Restrict to a session range only when at least one bound is given.
    if SID_ST or SID_ED:
        sid_indices = [f"session_{r:04}" for r in range(SID_ST or 0, SID_ED or 1000)]
    else:
        sid_indices = None

    run_with_agent(
        fp_out,
        get_gemma_response,
        gemma_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=sid_indices,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )
agents/llama3.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import os
3
+ import re
4
+
5
+ #%%
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
8
+ from agents import run_with_agent
9
+
10
+
11
+ #%%
12
+ def _getenv_as_int(attr, default=None):
13
+ ret = os.getenv(attr, default)
14
+ return None if ret is None else int(ret)
15
+
16
+
17
# Run configuration, read once at import time from TG_* environment variables.
# The *_ST / *_ED pairs are used as slice/range bounds further down; ``None``
# leaves that side open.
GAME_ST = _getenv_as_int("TG_GAME_ST", None)
GAME_ED = _getenv_as_int("TG_GAME_ED", None)
LVL_ST = _getenv_as_int("TG_LEVEL_ST", None)
LVL_ED = _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST = _getenv_as_int("TG_SID_ST", None)
SID_ED = _getenv_as_int("TG_SID_ED", None)
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
# Inserted after "Llama-3." in the repo id, e.g. "1-8" -> "Llama-3.1-8B-Instruct".
LLAMA_SIZE = os.getenv("TG_LLAMA_SIZE", "1-8")
23
+
24
+
25
+ #%%
26
def llama_postproc(response_txt, *args, **kwargs):
    """Return the model response unchanged, mapping ``None``/empty to ``""``.

    No answer extraction is applied for Llama outputs; extra arguments are
    accepted and ignored.
    """
    if response_txt:
        return response_txt
    return ""
39
+
40
+
41
+ #%%
42
def get_llama_response(texts, *args, **kwargs):
    """Generate Llama's next reply for a conversation (greedy decoding).

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        texts: Turn texts; even positions are user turns, odd positions
            assistant turns.
        *args, **kwargs: Ignored.

    Returns:
        The stripped, decoded text of at most 128 newly generated tokens.
    """
    messages = [
        {"role": ("assistant" if i % 2 else "user"), "content": text}
        for i, text in enumerate(texts)
    ]

    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The chat template has already rendered the prompt as text. Tokenizing
    # it with the default add_special_tokens=True can prepend a second BOS,
    # so it is disabled here, as the DeepSeek-R1 agent script already does.
    model_inputs = tokenizer(
        [text_inputs], return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    # Clear the sampling settings on the generation config; decoding below
    # is greedy.
    gen_cfg = model.generation_config
    gen_cfg.do_sample = False
    gen_cfg.temperature = None
    gen_cfg.top_k = None
    gen_cfg.top_p = None

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=128,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the echoed prompt tokens from each returned sequence.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response.strip()
73
+
74
+
75
+ #%%
76
if __name__ == "__main__":
    # The output path encodes the model size, the prompting mode and, when
    # set, the starting game / level indices of this run.
    shot_tag = '.1s' if ONE_SHOT else '.zs'
    game_tag = '' if GAME_ST is None else f'.{GAME_ST}'
    level_tag = '' if LVL_ST is None else f'.{LVL_ST}'
    fp_out = (f"model_outputs/__runs__/results_llama-3.{LLAMA_SIZE}b-instruct"
              f"{shot_tag}{game_tag}{level_tag}.jsonl")

    model_name = f"meta-llama/Llama-3.{LLAMA_SIZE}B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype="bfloat16",
    )

    # Restrict to a session range only when at least one bound is given.
    if SID_ST or SID_ED:
        sid_indices = [f"session_{r:04}" for r in range(SID_ST or 0, SID_ED or 1000)]
    else:
        sid_indices = None

    run_with_agent(
        fp_out,
        get_llama_response,
        llama_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=sid_indices,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )
agents/qwen2_5_7b_instruct.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import os
3
+ import re
4
+
5
+ #%%
6
+ import torch
7
+ import numpy as np
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
9
+ from textgames import GAME_NAMES, LEVEL_IDS
10
+ from agents import run_with_agent
11
+
12
+
13
+ #%%
14
def set_all_seed(seed=42):
    """Seed transformers, NumPy and torch (CPU and all CUDA devices) with `seed`."""
    for seeder in (set_seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
19
+
20
+
21
+ #%%
22
+ def _getenv_as_int(attr, default=None):
23
+ ret = os.getenv(attr, default)
24
+ return None if ret is None else int(ret)
25
+
26
+
27
+ GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
28
+ LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
29
+ SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
30
+ N_TURNS = _getenv_as_int("TG_N_TURNS", 3)
31
+ ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
32
+ QWEN_SIZE = int(os.getenv("TG_QWEN_SIZE", "32")) # {3, 7, 14, 32, 72} unsupported: {0.5, 1.5}
33
+
34
+
35
+ #%%
36
def qwen_postproc(response_txt, game_name, difficulty_level, *args, **kwargs):
    """Return the model response as-is; a missing/empty response becomes ``""``.

    No per-game answer extraction is applied; `game_name`, `difficulty_level`
    and any extra arguments are accepted but unused.
    """
    return response_txt if response_txt else ""
49
+
50
+
51
+ #%%
52
def get_qwen_response(texts_batch, game_name, difficulty_level, turn, *args, **kwargs):
    """Generate the next assistant reply for one conversation with the global Qwen model.

    Args:
        texts_batch: The conversation so far as a flat list of strings that
            alternate user / assistant, starting with a user message.
            (Despite the name, batching is not supported yet.)
        game_name, difficulty_level, turn: Accepted for the runner's calling
            convention; unused here.

    Returns:
        The decoded reply (at most 128 new tokens, greedy decoding), stripped.

    Relies on the module-level `tokenizer` and `model` created in ``__main__``.
    """
    # global model, tokenizer
    conversations = [texts_batch]  # currently we do not support batch
    messages = [
        [
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
            *[{"role": ("assistant" if i % 2 else "user"), "content": text} for i, text in enumerate(texts)],
        ]
        for texts in conversations
    ]

    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # For a batch of conversations the template call yields one prompt string per
    # conversation, i.e. already a list; wrapping it in another list would hand the
    # tokenizer a nested list. Only a bare string needs wrapping.
    if isinstance(text_inputs, str):
        text_inputs = [text_inputs]
    model_inputs = tokenizer(text_inputs, return_tensors="pt").to(model.device)

    # Clear sampling parameters so that greedy decoding (do_sample=False) is used consistently.
    model.generation_config.temperature = None
    model.generation_config.top_k = None
    model.generation_config.top_p = None
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=128,
        do_sample=False,
    )
    # Drop the prompt tokens and keep only the newly generated continuation.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response.strip()
81
+
82
+
83
+ #%%
84
if __name__ == "__main__":
    # Output file name encodes the model size, the prompting mode and the (optional) game/level start offsets.
    shot_tag = ".1s" if ONE_SHOT else ".zs"
    game_tag = "" if GAME_ST is None else f".{GAME_ST}"
    level_tag = "" if LVL_ST is None else f".{LVL_ST}"
    fp_out = f"model_outputs/__runs__/results_qwen2-5-{QWEN_SIZE}b-instruct{shot_tag}{game_tag}{level_tag}.jsonl"

    set_all_seed()
    # `tokenizer` and `model` are module-level names read by get_qwen_response.
    model_name = f"Qwen/Qwen2.5-{QWEN_SIZE}B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
    print(f"    > model.dtype: {model.dtype}")

    # Restrict the sessions only when at least one bound is given; otherwise pass None.
    if SID_ST or SID_ED:
        session_ids = [f"session_{r:04}" for r in range(SID_ST or 0, SID_ED or 1000)]
    else:
        session_ids = None

    run_with_agent(
        fp_out,
        get_qwen_response,
        qwen_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=session_ids,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )
agents/qwen2_5_math.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import os
3
+ import re
4
+
5
+ #%%
6
+ import torch
7
+ import numpy as np
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed, BitsAndBytesConfig
9
+ from textgames import GAME_NAMES, LEVEL_IDS
10
+ from agents import run_with_agent
11
+
12
+
13
+ #%%
14
def set_all_seed(seed=42):
    """Seed transformers, NumPy and torch (CPU and all CUDA devices) with `seed`."""
    for seeder in (set_seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
19
+
20
+
21
+ #%%
22
+ def _getenv_as_int(attr, default=None):
23
+ ret = os.getenv(attr, default)
24
+ return None if ret is None else int(ret)
25
+
26
+
27
+ GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
28
+ LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
29
+ SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
30
+ N_TURNS = _getenv_as_int("TG_N_TURNS", 3)
31
+ ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
32
+ # MAX_NEW_TOKENS = _getenv_as_int("TG_MAX_NEW_TOKENS", 4096)
33
+ QWEN_MATH_SIZE = os.getenv("TG_QWEN_MATH_SIZE", "7") # {1.5, 7, 72}
34
+ QUANTIZE = _getenv_as_int("TG_QUANTIZE", 4)
35
+
36
+
37
+ #%%
38
_BOXED_OPEN = "\\boxed{"
_FENCED_PAT = re.compile(r'^```\n?([^`]*)\n?```')
# re.compile(r'</think>\n([\s\S]*)$'),


def _last_boxed_content(text):
    r"""Return the content of the last brace-balanced ``\boxed{...}`` in `text`, or None.

    An unterminated ``\boxed{`` (e.g. a truncated generation) is skipped and
    the previous occurrence, if any, is tried instead.
    """
    start = text.rfind(_BOXED_OPEN)
    while start != -1:
        body_start = start + len(_BOXED_OPEN)
        depth = 1
        for pos in range(body_start, len(text)):
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[body_start:pos]
        start = text.rfind(_BOXED_OPEN, 0, start)
    return None


def qwenmath_postproc(response_txt_batch, *args, **kwargs):
    r"""Extract the final answer from a single Qwen-Math response.

    Extraction order:
      1. the content of the last balanced ``\boxed{...}`` (the model is asked
         to put its final answer there);
      2. a fenced code block that opens at the very start of the response;
      3. otherwise the response itself.

    Args:
        response_txt_batch: One response string (batching is not supported);
            None or an empty string yields ``""``.

    Returns:
        The extracted answer, stripped when it came from rule 1 or 2.
    """
    if not response_txt_batch:
        return ""
    boxed = _last_boxed_content(response_txt_batch)
    if boxed is not None:
        return boxed.strip()
    fenced = _FENCED_PAT.search(response_txt_batch)
    if fenced:
        return fenced.group(1).strip()
    return response_txt_batch
57
+
58
+
59
+ #%%
60
def get_qwenmath_response(texts_batch, *args, **kwargs):
    r"""Generate the next assistant reply for one conversation with the global Qwen-Math model.

    Args:
        texts_batch: The conversation so far as a flat list of strings that
            alternate user / assistant, starting with a user message.
            (Despite the name, batching is not supported yet.)

    Returns:
        The decoded reply (at most 512 new tokens, greedy decoding), stripped.

    Relies on the module-level `tokenizer` and `model` created in ``__main__``.
    """
    # global model, tokenizer
    # Work on a copy: the caller re-sends the same list on every turn, so boxing the
    # example answer in place would wrap it again and again (\boxed{\boxed{...}}).
    conversations = [list(texts_batch)]
    for texts in conversations:
        # A one-shot example occupies texts[0:2] and is followed by a "Correct guess." turn;
        # present its answer in the \boxed{} format requested by the system prompt.
        if len(texts) > 2 and texts[2].startswith('Correct guess.'):
            texts[1] = f"\\boxed{{{texts[1]}}}"
    messages = [
        [
            {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{} as plain text."},
            *[{"role": ("user" if i % 2 == 0 else "assistant"), "content": text} for i, text in enumerate(texts)],
        ]
        for texts in conversations
    ]
    # print(f"\n{messages[0]}", end="\n=====\n\n")

    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer(text_inputs, return_tensors="pt", add_special_tokens=False).to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=False,
    )
    # Drop the prompt tokens and keep only the newly generated continuation.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response.strip()
93
+
94
+
95
+ #%%
96
if __name__ == "__main__":
    # Output file name encodes model size, quantization setting, prompting mode and
    # the (optional) game/level start offsets.
    shot_tag = ".1s" if ONE_SHOT else ".zs"
    game_tag = "" if GAME_ST is None else f".{GAME_ST}"
    level_tag = "" if LVL_ST is None else f".{LVL_ST}"
    fp_out = (f"model_outputs/__runs__/results_qwen2-5-math-{QWEN_MATH_SIZE}b-instruct_{QUANTIZE}bit"
              f"{shot_tag}{game_tag}{level_tag}.jsonl")

    set_all_seed()
    # Only the 72B model is quantized (8-bit when requested, 4-bit otherwise);
    # every other configuration is loaded with automatic device placement.
    if QWEN_MATH_SIZE == '72' and QUANTIZE < 16:
        if QUANTIZE == 8:
            quant_cfg = BitsAndBytesConfig(load_in_8bit=True)
        else:
            quant_cfg = BitsAndBytesConfig(load_in_4bit=True)
        _additional_kwargs = {"quantization_config": quant_cfg, "low_cpu_mem_usage": True}
    else:
        _additional_kwargs = {"device_map": "auto"}

    # `tokenizer` and `model` are module-level names read by get_qwenmath_response.
    model_name = f"Qwen/Qwen2.5-Math-{QWEN_MATH_SIZE}B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", **_additional_kwargs)
    print(f"    > model.dtype: {model.dtype}")

    # Restrict the sessions only when at least one bound is given; otherwise pass None.
    if SID_ST or SID_ED:
        session_ids = [f"session_{r:04}" for r in range(SID_ST or 0, SID_ED or 1000)]
    else:
        session_ids = None

    run_with_agent(
        fp_out,
        get_qwenmath_response,
        qwenmath_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=session_ids,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
        assistant_uses_raw_response=True,
    )
agents/runner.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import os
3
+ import json
4
+
5
+ from textgames import GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name
6
+
7
+ from tqdm import tqdm
8
+ from itertools import product
9
+ from pathlib import Path
10
+ from typing import Union, Callable
11
+
12
+
13
def response_postprocess(response_txt, game_name, difficulty_level, *args, **kwargs):
    """Default post-processor: return the response unchanged, or ``""`` when it is missing.

    `game_name`, `difficulty_level` and any extra arguments are accepted for
    signature compatibility with model-specific post-processors and are unused.
    """
    return response_txt or ""
15
+
16
+
17
def run_with_agent(fp_out: Union[str, Path],
                   get_response: Callable,
                   get_postprocess: Callable = response_postprocess,
                   n_turns=3,
                   game_names_list=GAME_NAMES,
                   level_ids_list=LEVEL_IDS[:3],
                   sid_indices=None,  # sid_index_range=range(0, 1000),
                   remove_if_output_file_exist=True,
                   prepend_example=False,
                   assistant_uses_raw_response=True,
                   problemset_dir: Union[str, Path] = "problemsets",
                   ) -> None:
    """Play every (game, level, session) with an agent and log each turn as a JSON line.

    Args:
        fp_out: Path of the JSONL result file; parent directories are created.
        get_response: Called as ``get_response(texts, game_name, difficulty_level,
            turn, sid=sid)`` and returns the raw model response.
        get_postprocess: Called as ``get_postprocess(response_raw, game_name,
            difficulty_level)`` and returns the answer passed to the game validator.
        n_turns: Maximum number of attempts per session.
        game_names_list: Games to play.
        level_ids_list: Difficulty level ids to play.
        sid_indices: Session ids to play; None plays every session in the
            problem set. Unknown ids raise KeyError.
        remove_if_output_file_exist: When true, `fp_out` is truncated first;
            otherwise results are appended.
        prepend_example: When true, the game's worked example is prepended to
            the conversation (one-shot prompting).
        assistant_uses_raw_response: When true, the raw response (rather than
            the post-processed answer) is kept in the conversation history.
        problemset_dir: Directory containing ``<game>_<level>.json`` problem sets.

    Any exception raised while producing or validating a response is recorded
    in the "error" field of that turn's record and ends the session.
    """
    os.makedirs(os.path.dirname(os.path.abspath(fp_out)), exist_ok=True)
    print(fp_out)
    if remove_if_output_file_exist:
        # Truncate (or create) the output file.
        with open(fp_out, "wb"):
            pass

    for game_name, difficulty_level in product(game_names_list, level_ids_list):
        game_str = f"{game_filename(game_name)}_{difficulty_level}"
        game_cls = _game_class_from_name(game_name)
        with open(Path(problemset_dir) / f"{game_str}.json", "r", encoding="utf8") as f:
            sid_prompt_dict = json.load(f)
        if sid_indices is not None:
            sid_prompt_dict = {k: sid_prompt_dict[k] for k in sid_indices}

        correct_cnt, exception_cnt = 0, 0
        for sid, prompt in tqdm(sid_prompt_dict.items(), desc=game_str, total=len(sid_prompt_dict)):
            cur_game = game_cls()
            cur_game.load_game(prompt)
            if prepend_example:
                texts = [*cur_game.example(), f"Correct guess. Now let's try another example.\n{cur_game.get_prompt()}"]
            else:
                texts = [cur_game.get_prompt()]
            for turn in range(1, n_turns + 1):
                response_raw, response, e = None, None, None
                solved, val_msg = False, None
                try:
                    response_raw = get_response(texts, game_name, difficulty_level, turn, sid=sid)
                    response = get_postprocess(response_raw, game_name, difficulty_level)
                    texts.append(response_raw if assistant_uses_raw_response else response)
                    solved, val_msg = (False, None) if response is None else cur_game.validate(response)
                    texts.append(
                        f"Bad guess (Wrong Answer).\n{val_msg}\nPlease try again and print the answer only."
                        if not solved else "Correct guess."
                    )
                except Exception as _e:
                    # Best effort by design: the failure is logged with the turn record below.
                    e = _e
                # The file is re-opened per turn so that finished turns survive a crash.
                with open(fp_out, "a", encoding="utf8") as o:
                    json.dump({
                        "game": game_str,
                        "session": sid,
                        "turn": turn,
                        "response": response,
                        "solved": solved,
                        "val_msg": val_msg,
                        "response_raw": response_raw,
                        "error": None if e is None else repr(e),
                    }, o, ensure_ascii=False)
                    o.write("\n")
                if solved:
                    correct_cnt += 1
                if e is not None:
                    exception_cnt += 1
                if solved or e is not None:
                    break

        # Guard the ratios: an empty problem set / selection must not raise ZeroDivisionError.
        n_sessions = len(sid_prompt_dict)
        correct_ratio = correct_cnt / n_sessions if n_sessions else 0.0
        exception_ratio = exception_cnt / n_sessions if n_sessions else 0.0
        print(f"{game_filename(game_name)}_-_{difficulty_level}")
        print(f"    > Correct: {correct_cnt:>6,}  ({correct_ratio:.2%})")
        print(f"    > Error  : {exception_cnt:>6,}  ({exception_ratio:.2%})")
89
+
play_gradio.py CHANGED
@@ -51,7 +51,7 @@ def greet(request: gr.Request):
51
 
52
  #%%
53
  with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
54
- ((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle),
55
  (session_state, is_solved, solved_games, user_state, uid_state),
56
  ) = declare_components(demo, greet)
57
 
@@ -64,7 +64,7 @@ demo.launch(
64
  auth=file_based_auth,
65
  favicon_path=favicon_path if os.path.exists(favicon_path) else None,
66
  share=True,
67
- ssr_mode=False,
68
  )
69
 
70
 
 
51
 
52
  #%%
53
  with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
54
+ ((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
55
  (session_state, is_solved, solved_games, user_state, uid_state),
56
  ) = declare_components(demo, greet)
57
 
 
64
  auth=file_based_auth,
65
  favicon_path=favicon_path if os.path.exists(favicon_path) else None,
66
  share=True,
67
+ show_api=False,
68
  )
69
 
70
 
play_helper.py CHANGED
@@ -1,6 +1,7 @@
1
  # %%
2
  import os
3
  import time
 
4
  import pandas as pd
5
  import gradio as gr
6
  import hashlib
@@ -19,19 +20,27 @@ from googleapiclient.discovery import build
19
  from googleapiclient.errors import HttpError
20
  from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload
21
 
 
 
 
22
 
23
  # %%
24
- def declare_components(demo, greet):
25
  with gr.Row():
26
  with gr.Column(scale=1):
27
  m = gr.Markdown("Welcome to TextGames!", elem_id="md-greeting")
28
- logout_btn = gr.Button("Logout", link="/logout", variant='huggingface', size='sm', elem_id="btn-logout")
 
 
 
 
 
29
  with gr.Column(scale=2):
30
- solved_games_df = gr.DataFrame(headers=[g.split('\t', 1)[0] for g in GAME_NAMES], label="Finished Games",
31
- interactive=False, elem_id="df-solved-games")
32
- game_radio = gr.Radio(GAME_NAMES, label="Game", elem_id="radio-game-name")
33
- level_radio = gr.Radio(LEVELS, label="Level", elem_id="radio-level-name")
34
- new_game_btn = gr.Button("Start Game", elem_id="btn-start-game")
35
  render_toggle = gr.Checkbox(False, visible=False, interactive=False)
36
 
37
  # cur_game_start = gr.BrowserState()
@@ -41,9 +50,12 @@ def declare_components(demo, greet):
41
  user_state = gr.State()
42
  uid_state = gr.State()
43
 
 
 
 
44
  session_state.change(
45
- lambda s: session_state_change_fn(s, 2, 0, 2, 0),
46
- [session_state], [game_radio, level_radio, new_game_btn, logout_btn], js=js_remove_input_helper,
47
  )
48
  new_game_btn.click(check_to_start_new_game, [game_radio, level_radio, user_state, uid_state], [session_state])
49
  solved_games.change(solved_games_change_fn, solved_games, solved_games_df)
@@ -54,13 +66,15 @@ def declare_components(demo, greet):
54
  ).then(
55
  lambda: gr.update(interactive=False), None, [new_game_btn],
56
  ).then(
57
- check_played_game, [solved_games, uid_state], [solved_games, solved_games_df]
58
  ).then(
59
- lambda: gr.update(interactive=True), None, [new_game_btn],
 
 
60
  )
61
 
62
  return (
63
- (m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle),
64
  (session_state, is_solved, solved_games, user_state, uid_state),
65
  )
66
 
@@ -489,7 +503,8 @@ def _is_checksum_same(fp_out, matches=None, mime_type="application/octet-stream"
489
  matches = _files.list(
490
  q=f"'{_folder_id}' in parents and mimeType='{mime_type}' and name = '{fp_out.rsplit('/', 1)[-1]}'",
491
  fields=f"files(name, id, {_cksm_methods_str})",
492
- ).execute()['files']
 
493
  if not os.path.exists(fp_out):
494
  return None, None, matches
495
  with open(fp_out, "rb") as o:
@@ -502,9 +517,9 @@ def _is_checksum_same(fp_out, matches=None, mime_type="application/octet-stream"
502
 
503
 
504
  # %%
505
- def upload_to_drive(fp_out, matches=None, mime_type="application/octet-stream", compare_checksum=True):
506
  if compare_checksum:
507
- same_checksum, _, _ = _is_checksum_same(fp_out, matches, mime_type)
508
  # same_checksum, _, _ = _is_checksum_same(
509
  # fp_out, **{k: v for k, v in [('matches', matches), ('mime_type', mime_type)] if v})
510
  if same_checksum:
@@ -513,7 +528,11 @@ def upload_to_drive(fp_out, matches=None, mime_type="application/octet-stream",
513
  file_metadata = {"name": fn, "parents": [_folder_id]}
514
  media = MediaFileUpload(fp_out)
515
  try:
516
- _files.create(body=file_metadata, media_body=media).execute()
 
 
 
 
517
  except HttpError as error:
518
  msg = f"Failed to upload the file, error: {error}"
519
  print(msg)
@@ -547,7 +566,7 @@ def download_from_drive(fp_out, matches=None, mime_type="application/octet-strea
547
 
548
  # %%
549
  def start_new_game(game_name, level, session_state_component, is_solved_component, solved_games_component,
550
- user=None, show_timer=False, uid=None):
551
  # cur_game_id = GAME_IDS[GAME_NAMES.index(game_name)]
552
  difficulty_level = LEVEL_IDS[LEVELS.index(level)]
553
 
@@ -555,11 +574,16 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
555
  # elapsed_text = gr.Textbox("N/A", label=f"{game_name}", info=f"{level}", )
556
  # gr.Timer(.3).tick(_calc_time_elapsed, [cur_game_start, elapsed_text, is_solved_component], [elapsed_text])
557
 
558
- fp_out = _get_file_output(game_name, difficulty_level, uid)
 
 
 
559
  cur_game = (
560
  new_game(game_name, difficulty_level)
561
  if user is None else
562
  preload_game(game_name, difficulty_level, user)
 
 
563
  )
564
  cur_game.attach_stats_output_(fp_out)
565
  cur_game.flush_stats_(user_info_to_flush=user)
@@ -616,8 +640,12 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
616
  js=js_submit)
617
  give_up_checkbox = gr.Checkbox(False, visible=False, interactive=False)
618
  give_up_btn.click(
 
 
619
  lambda x: x, [give_up_checkbox], [give_up_checkbox],
620
  js="(x) => confirm('🥹 Give-up? 💸')"
 
 
621
  )
622
 
623
  def _forfeiting(confirmed, _solved_games):
@@ -640,6 +668,8 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
640
  def game_is_solved(_is_solved, _session_state, _solved_games, progress=gr.Progress()):
641
  if _is_solved:
642
  if level in LEVELS and level not in _solved_games[game_name]:
 
 
643
  _solved_games[game_name].append(level)
644
  return (
645
  2,
@@ -655,8 +685,16 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
655
 
656
  def finalize_game(_is_solved):
657
  if _is_solved:
658
- gr.Info("Reporting... Please click the button when available...")
 
 
 
 
 
 
 
659
  upload_to_drive(fp_out)
 
660
  return gr.update(interactive=True)
661
  return gr.update()
662
 
@@ -673,13 +711,14 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
673
 
674
 
675
  # %%
676
- def check_to_start_new_game(game_name, level, user=None, uid=None):
677
- print(game_name, level)
678
  if game_name is None or level is None:
679
  raise gr.Error("please choose both Game & Level")
680
- fp = _get_file_output(game_name, LEVEL_IDS[LEVELS.index(level)], uid)
681
  if os.path.exists(fp):
682
- raise gr.Error(f"You have done this game already.<br/>{game_name} - {level}")
 
683
  if user is None:
684
  gr.Warning("no user, game will be generated randomly")
685
  # else:
@@ -691,16 +730,19 @@ def check_to_start_new_game(game_name, level, user=None, uid=None):
691
 
692
 
693
  # %%
694
- def check_played_game(solved_games, uid, progress=gr.Progress()):
 
 
695
  matches = _files.list(
696
  q=f"'{_folder_id}' in parents and mimeType='application/octet-stream' and name contains '{uid}_-_'",
697
  fields=f"files(name, id, {_cksm_methods_str})",
698
- ).execute()['files']
 
699
  ret = dict()
700
  for game_name in solved_games.keys():
701
  cur = []
702
  for level, level_id in zip(LEVELS, LEVEL_IDS):
703
- fp_out = _get_file_output(game_name, level_id, uid)
704
  _matches = list(filter(lambda m: fp_out.endswith(m['name']), matches))
705
  if os.path.exists(fp_out):
706
  upload_to_drive(fp_out, _matches)
@@ -708,7 +750,7 @@ def check_played_game(solved_games, uid, progress=gr.Progress()):
708
  download_from_drive(fp_out, _matches)
709
  if os.path.exists(fp_out):
710
  cur.append(level)
711
- ret[game_name] = cur
712
  return ret, gr.update()
713
 
714
 
 
1
  # %%
2
  import os
3
  import time
4
+ import json
5
  import pandas as pd
6
  import gradio as gr
7
  import hashlib
 
20
  from googleapiclient.errors import HttpError
21
  from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload
22
 
23
+ # %%
24
+ _leaderboards = f"{os.getenv('TEXTGAMES_OUTPUT_DIR', '.')}/_leaderboards.jsonl"
25
+
26
 
27
  # %%
28
+ def declare_components(demo, greet, use_login_button=False):
29
  with gr.Row():
30
  with gr.Column(scale=1):
31
  m = gr.Markdown("Welcome to TextGames!", elem_id="md-greeting")
32
+ if use_login_button:
33
+ logout_btn = gr.LoginButton(size='sm')
34
+ reset_sid_btn = gr.Button("♻️ Reset Game Progress", variant='huggingface', size='sm')
35
+ else:
36
+ logout_btn = gr.Button("Logout", link="/logout", variant='huggingface', size='sm', elem_id="btn-logout")
37
+ reset_sid_btn = gr.Button(interactive=False, visible=False, size='sm')
38
  with gr.Column(scale=2):
39
+ solved_games_df = gr.DataFrame(headers=[g.split('\t', 1)[0] for g in GAME_NAMES], label="Attempted Games",
40
+ row_count=(1, 'fixed'), interactive=False, elem_id="df-solved-games")
41
+ level_radio = gr.Radio(LEVELS, label="Level", elem_id="radio-level-name", visible=False)
42
+ game_radio = gr.Radio(GAME_NAMES, label="Game", elem_id="radio-game-name", visible=False)
43
+ new_game_btn = gr.Button("Start Game", elem_id="btn-start-game", visible=False)
44
  render_toggle = gr.Checkbox(False, visible=False, interactive=False)
45
 
46
  # cur_game_start = gr.BrowserState()
 
50
  user_state = gr.State()
51
  uid_state = gr.State()
52
 
53
+ if not os.path.exists(_leaderboards):
54
+ download_from_drive(_leaderboards, compare_checksum=False)
55
+
56
  session_state.change(
57
+ lambda s: session_state_change_fn(s, 2, 0, 3, 0),
58
+ [session_state], [game_radio, level_radio, new_game_btn, logout_btn, reset_sid_btn], js=js_remove_input_helper,
59
  )
60
  new_game_btn.click(check_to_start_new_game, [game_radio, level_radio, user_state, uid_state], [session_state])
61
  solved_games.change(solved_games_change_fn, solved_games, solved_games_df)
 
66
  ).then(
67
  lambda: gr.update(interactive=False), None, [new_game_btn],
68
  ).then(
69
+ check_played_game, [solved_games, user_state], [solved_games, solved_games_df]
70
  ).then(
71
+ lambda uid: ([gr.update(visible=True, interactive=True)] if uid else
72
+ [gr.update(visible=False, interactive=False)]) * 3,
73
+ [uid_state], [level_radio, game_radio, new_game_btn]
74
  )
75
 
76
  return (
77
+ (m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
78
  (session_state, is_solved, solved_games, user_state, uid_state),
79
  )
80
 
 
503
  matches = _files.list(
504
  q=f"'{_folder_id}' in parents and mimeType='{mime_type}' and name = '{fp_out.rsplit('/', 1)[-1]}'",
505
  fields=f"files(name, id, {_cksm_methods_str})",
506
+ ).execute()
507
+ matches = matches['files']
508
  if not os.path.exists(fp_out):
509
  return None, None, matches
510
  with open(fp_out, "rb") as o:
 
517
 
518
 
519
  # %%
520
+ def upload_to_drive(fp_out, matches=None, mime_type="application/octet-stream", compare_checksum=True, update=False):
521
  if compare_checksum:
522
+ same_checksum, _, matches = _is_checksum_same(fp_out, matches, mime_type)
523
  # same_checksum, _, _ = _is_checksum_same(
524
  # fp_out, **{k: v for k, v in [('matches', matches), ('mime_type', mime_type)] if v})
525
  if same_checksum:
 
528
  file_metadata = {"name": fn, "parents": [_folder_id]}
529
  media = MediaFileUpload(fp_out)
530
  try:
531
+ if update and matches:
532
+ file_metadata.pop("parents")
533
+ _files.update(fileId=matches[0]['id'], body=file_metadata, media_body=media).execute()
534
+ else:
535
+ _files.create(body=file_metadata, media_body=media).execute()
536
  except HttpError as error:
537
  msg = f"Failed to upload the file, error: {error}"
538
  print(msg)
 
566
 
567
  # %%
568
  def start_new_game(game_name, level, session_state_component, is_solved_component, solved_games_component,
569
+ user=None, show_timer=False, uid=None, sid=None):
570
  # cur_game_id = GAME_IDS[GAME_NAMES.index(game_name)]
571
  difficulty_level = LEVEL_IDS[LEVELS.index(level)]
572
 
 
574
  # elapsed_text = gr.Textbox("N/A", label=f"{game_name}", info=f"{level}", )
575
  # gr.Timer(.3).tick(_calc_time_elapsed, [cur_game_start, elapsed_text, is_solved_component], [elapsed_text])
576
 
577
+ if (not sid) and user and ('sid' in user):
578
+ sid = user['sid']
579
+
580
+ fp_out = _get_file_output(game_name, difficulty_level, f"{uid}_{sid}")
581
  cur_game = (
582
  new_game(game_name, difficulty_level)
583
  if user is None else
584
  preload_game(game_name, difficulty_level, user)
585
+ if sid is None else
586
+ preload_game(game_name, difficulty_level, user, sid=sid)
587
  )
588
  cur_game.attach_stats_output_(fp_out)
589
  cur_game.flush_stats_(user_info_to_flush=user)
 
640
  js=js_submit)
641
  give_up_checkbox = gr.Checkbox(False, visible=False, interactive=False)
642
  give_up_btn.click(
643
+ lambda: (gr.update(interactive=False), gr.update(interactive=False)), None, [submit_btn, give_up_btn]
644
+ ).then(
645
  lambda x: x, [give_up_checkbox], [give_up_checkbox],
646
  js="(x) => confirm('🥹 Give-up? 💸')"
647
+ ).then(
648
+ lambda: (gr.update(interactive=True), gr.update(interactive=True)), None, [submit_btn, give_up_btn]
649
  )
650
 
651
  def _forfeiting(confirmed, _solved_games):
 
668
  def game_is_solved(_is_solved, _session_state, _solved_games, progress=gr.Progress()):
669
  if _is_solved:
670
  if level in LEVELS and level not in _solved_games[game_name]:
671
+ if isinstance(_solved_games[game_name], str):
672
+ _solved_games[game_name] = []
673
  _solved_games[game_name].append(level)
674
  return (
675
  2,
 
685
 
686
  def finalize_game(_is_solved):
687
  if _is_solved:
688
+ gr.Info(f"Wrapping things up... Please click the button when available...<br/>"
689
+ f"Time: {cur_game.end_timestamp-cur_game.start_timestamp:4.1f} sec. Attempt: {cur_game.attempt_count}.")
690
+ with open(_leaderboards, "a", encoding="utf-8") as f:
691
+ json.dump({'uid': uid, 'sid': sid, 'turns': cur_game.attempt_count,
692
+ 'st': cur_game.start_timestamp, 'ed': cur_game.end_timestamp,
693
+ 'game_name': game_name, 'difficulty_level': difficulty_level,
694
+ }, f)
695
+ f.write("\n")
696
  upload_to_drive(fp_out)
697
+ upload_to_drive(_leaderboards, update=True)
698
  return gr.update(interactive=True)
699
  return gr.update()
700
 
 
711
 
712
 
713
  # %%
714
+ def check_to_start_new_game(game_name, level, user=None, uid=None, sid=None):
715
+ print(game_name, level, uid, sid)
716
  if game_name is None or level is None:
717
  raise gr.Error("please choose both Game & Level")
718
+ fp = _get_file_output(game_name, LEVEL_IDS[LEVELS.index(level)], f"{uid}_{sid}")
719
  if os.path.exists(fp):
720
+ # raise gr.Error(f"You have done this game already.<br/>{game_name} - {level}")
721
+ gr.Warning("You have done this game already. Only first attempt is recorded in the scoreboard.")
722
  if user is None:
723
  gr.Warning("no user, game will be generated randomly")
724
  # else:
 
730
 
731
 
732
  # %%
733
+ def check_played_game(solved_games, user, progress=gr.Progress()):
734
+ uid = user['email']
735
+ sid = user.get('sid', None)
736
  matches = _files.list(
737
  q=f"'{_folder_id}' in parents and mimeType='application/octet-stream' and name contains '{uid}_-_'",
738
  fields=f"files(name, id, {_cksm_methods_str})",
739
+ ).execute()
740
+ matches = matches['files']
741
  ret = dict()
742
  for game_name in solved_games.keys():
743
  cur = []
744
  for level, level_id in zip(LEVELS, LEVEL_IDS):
745
+ fp_out = _get_file_output(game_name, level_id, f"{uid}_{sid}")
746
  _matches = list(filter(lambda m: fp_out.endswith(m['name']), matches))
747
  if os.path.exists(fp_out):
748
  upload_to_drive(fp_out, _matches)
 
750
  download_from_drive(fp_out, _matches)
751
  if os.path.exists(fp_out):
752
  cur.append(level)
753
+ ret[game_name] = cur or '∅'
754
  return ret, gr.update()
755
 
756
 
play_with_auth.py CHANGED
@@ -130,7 +130,7 @@ with gr.Blocks(title="TextGames") as login_demo:
130
  app = gr.mount_gradio_app(app, login_demo, path="/login")
131
 
132
  with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
133
- ((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle),
134
  (session_state, is_solved, solved_games, user_state, uid_state),
135
  ) = declare_components(demo, greet)
136
 
 
130
  app = gr.mount_gradio_app(app, login_demo, path="/login")
131
 
132
  with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
133
+ ((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
134
  (session_state, is_solved, solved_games, user_state, uid_state),
135
  ) = declare_components(demo, greet)
136
 
play_with_hf.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ #%%
4
+ import os
5
+ # os.environ.setdefault("GRADIO_SERVER_PORT", "1080")
6
+ # os.environ.setdefault("TEXTGAMES_SHOW_HIDDEN_LEVEL", "1")
7
+ os.environ.setdefault("TEXTGAMES_LOADGAME_DIR", "problemsets")
8
+ os.environ.setdefault("TEXTGAMES_LOADGAME_ID", "42")
9
+ os.environ.setdefault("TEXTGAMES_MOCKUSER", "")
10
+ os.environ.setdefault("TEXTGAMES_OUTPUT_DIR", "user_outputs")
11
+ favicon_path = "textgames-scrabble-black2-ss.png"
12
+
13
+ #%%
14
+ from play_helper import css, declare_components, start_new_game, check_played_game, download_from_drive, upload_to_drive, _leaderboards
15
+ import pandas as pd
16
+ import gradio as gr
17
+ import random
18
+ import json
19
+ from textgames import GAME_NAMES
20
+
21
+
22
+ #%%
23
+ os.makedirs(os.getenv('TEXTGAMES_OUTPUT_DIR', '.'), exist_ok=True)
24
+
25
+
26
+ #%%
27
def generate_sid(fp):
    """Write a freshly drawn random session id to `fp` and upload the file to Drive.

    The file is overwritten with a single line of the form ``session_NNNN``.
    """
    # randrange(1000) draws from 0..999. randint(0, 1000) is inclusive and could yield
    # "session_1000", while the agent scripts in this repo enumerate sessions as range(0, 1000).
    rand_int = random.randrange(1000)
    with open(fp, "w", encoding="utf8") as f:
        f.write(f"session_{rand_int:04}\n")
    upload_to_drive(fp, mime_type="text/plain", update=True)
32
+
33
+
34
+ #%%
35
def get_sid(uid, force_generate_sid=False):
    """Return the session id assigned to user `uid`, creating one when needed.

    Lookup order: the local ``<output dir>/<uid>_sid.txt`` file, then the copy
    on Drive, and finally a newly generated id.

    Args:
        uid: User identifier used in the sid file name.
        force_generate_sid: When true, a new id is generated (and uploaded)
            before reading, replacing any existing one.

    Returns:
        The last non-blank line of the sid file, stripped.
    """
    # Same fallback as the rest of this module, so the path never becomes "None/...".
    fp = f"{os.getenv('TEXTGAMES_OUTPUT_DIR', '.')}/{uid}_sid.txt"
    if force_generate_sid:
        generate_sid(fp)
    if not os.path.exists(fp):
        download_from_drive(fp, mime_type="text/plain", compare_checksum=False)
    if not os.path.exists(fp):
        generate_sid(fp)

    sid = None
    # Second pass only happens when the file existed but held no usable id.
    for _attempt in range(2):
        with open(fp, "r", encoding="utf8") as f:
            lines = [line.strip() for line in f if line.strip()]
        if lines:
            sid = lines[-1]
            break
        generate_sid(fp)
    if sid is None:
        raise RuntimeError(f"could not obtain a session id from {fp}")
    return sid
46
+
47
+
48
+ #%%
49
def greet(request: gr.OAuthProfile | None):
    """Build the greeting text and the user record for the current visitor.

    Returns a ``(markdown_text, user_dict, user_email)`` tuple. Without an
    OAuth profile the mock user from ``TEXTGAMES_MOCKUSER`` is used and the
    visitor is asked to log in.
    """
    if request is None:
        user = {'email': os.getenv('TEXTGAMES_MOCKUSER', ''), 'name': ""}
    else:
        user = {'email': request.username, 'name': request.name, 'sid': get_sid(request.username)}
    display_name = user['name'] or 'please login'
    greeting = f"""
    Welcome to TextGames, {display_name}!
    """
    return greeting, user, user['email']
56
+
57
+
58
+ #%%
59
# Main page: shared components come from declare_components (login-button variant);
# this block adds the "reset progress" flow and the game renderer.
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
    ((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
     (session_state, is_solved, solved_games, user_state, uid_state),
     ) = declare_components(demo, greet, use_login_button=True)
    logout_btn.activate()

    # Hidden checkbox that carries the result of the browser-side confirm() dialog.
    reset_sid_checkbox = gr.Checkbox(False, visible=False, interactive=False)
    # Click flow: disable the buttons, ask for confirmation in the browser, re-enable the buttons.
    reset_sid_btn.click(
        lambda: [gr.update(interactive=False)]*2, None, [reset_sid_btn, new_game_btn]
    ).then(
        lambda x: x, [reset_sid_checkbox], [reset_sid_checkbox],
        js="(x) => confirm('Reset Progress? (cannot be undone)')"
    ).then(
        lambda: [gr.update(interactive=True)]*2, None, [reset_sid_btn, new_game_btn]
    )

    def _resetting(confirmed, user):
        # Assign a newly generated session id to a logged-in user who confirmed the reset;
        # always returns False for the checkbox so it is cleared again.
        uid = user.get('email', None) if isinstance(user, dict) else None
        if uid is None:
            gr.Warning("You need to log in first!")
        elif confirmed:
            user['sid'] = get_sid(uid, force_generate_sid=True)
        return user, False
    # NOTE(review): _resetting writes False back into the checkbox, which presumably
    # fires this change chain a second time with confirmed=False — confirm this is harmless.
    reset_sid_checkbox.change(
        lambda: [gr.update(interactive=False)]*3, None, [logout_btn, reset_sid_btn, new_game_btn]
    ).then(
        _resetting, [reset_sid_checkbox, user_state], [user_state, reset_sid_checkbox]
    ).then(
        check_played_game, [solved_games, user_state], [solved_games, solved_games_df]
    ).then(
        lambda: [gr.update(interactive=True)]*3, None, [logout_btn, reset_sid_btn, new_game_btn]
    )


    # Re-rendered whenever render_toggle changes; a game is only built in session states 1 and 2.
    @gr.render(inputs=[game_radio, level_radio, user_state, session_state, uid_state], triggers=[render_toggle.change])
    def _start_new_game(game_name, level, user, _session_state, _uid_state):
        if _session_state in [1, 2]:
            start_new_game(game_name, level, session_state, is_solved, solved_games, user=user, uid=_uid_state)
97
+
98
+ #%%
99
# Leaderboard page (placeholder): one "Rankings" table per difficulty tab, filled with dummy data on load.
with demo.route("Leaderboards", "/leaderboard") as demo_leaderboard:
    gr.Markdown("Under Construction. Will be available soon.")
    leaderboards = []
    for tab in ["🚅 Easy", "🚀 Medium", "🛸 Hard"]:
        with gr.Tab(tab):
            leaderboards.append(gr.DataFrame(label="Rankings"))

    # if os.path.exists(_leaderboards):
    #     datas = []
    #     with open(_leaderboards, "r", encoding="utf8") as f:
    #         for line in f:
    #             datas.append(json.loads(line))
    #     concat = [{'Level': d['difficulty_level'], 'User': d['uid'], 'Game': d['game_name'].split('\t', 1)[0], 'Attempts': d['turns'],
    #                "Time": d['ed'] - d['st']} for d in datas]
    # else:
    def add_dummies():
        # Single placeholder row; the 'Solved' cell lists the part of every game name before the first tab.
        return pd.DataFrame({
            'User': ['dummy'],
            'Solved': [' '.join([g.split('\t', 1)[0] for g in GAME_NAMES])],
            'Attempts': [8],
            'Time': [7200.8],
        })
    for l in leaderboards:
        demo_leaderboard.load(add_dummies, None, [l])


#%%
# demo.launch()
# The favicon is passed only when the file exists next to the script.
demo.launch(
    favicon_path=favicon_path if os.path.exists(favicon_path) else None,
    show_api=False,
)
131
+
132
+
problemsets/Anagram Scribble_1.json CHANGED
The diff for this file is too large to render. See raw diff
 
problemsets/Anagram Scribble_2.json CHANGED
The diff for this file is too large to render. See raw diff
 
problemsets/Anagram Scribble_3.json CHANGED
The diff for this file is too large to render. See raw diff
 
problemsets/Bracket Game_1.json CHANGED
The diff for this file is too large to render. See raw diff
 
problemsets/Bracket Game_2.json CHANGED
The diff for this file is too large to render. See raw diff
 
problemsets/Bracket Game_3.json CHANGED
The diff for this file is too large to render. See raw diff
 
problemsets/Crossword Arranger_1.json CHANGED
The diff for this file is too large to render. See raw diff
 
problemsets/Crossword Arranger_2.json CHANGED
The diff for this file is too large to render. See raw diff
 
problemsets/Crossword Arranger_3.json CHANGED
The diff for this file is too large to render. See raw diff
 
reval_ana3.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from tqdm import tqdm
4
+ from textgames import GAME_NAMES, game_filename, _game_class_from_name
5
+ from pathlib import Path
6
+
7
+ GAME_NAME = GAME_NAMES[5]
8
+ PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
9
+ MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
10
+ OUTPUT_FILENAMES = [
11
+ # "results_gemma-2-9b-it.1s.jsonl",
12
+ # "results_gemma-2-9b-it.zs.jsonl",
13
+ # "results_gemma-2-27b-it.1s.jsonl",
14
+ # "results_gemma-2-27b-it.zs.jsonl",
15
+ #
16
+ # "results_llama-3.1-8b-instruct.1s.jsonl",
17
+ # "results_llama-3.1-8b-instruct.zs.jsonl",
18
+ # "results_llama-3.1-70b-instruct.1s.jsonl",
19
+ # "results_llama-3.1-70b-instruct.zs.jsonl",
20
+ # "results_llama-3.3-70b-instruct.1s.jsonl",
21
+ # "results_llama-3.3-70b-instruct.zs.jsonl",
22
+ #
23
+ # "results_qwen2-5-7b-instruct.1s.jsonl",
24
+ # "results_qwen2-5-7b-instruct.zs.jsonl",
25
+ # "results_qwen2-5-14b-instruct.1s.jsonl",
26
+ # "results_qwen2-5-14b-instruct.zs.jsonl",
27
+ # "results_qwen2-5-32b-instruct.1s.jsonl",
28
+ # "results_qwen2-5-32b-instruct.zs.jsonl",
29
+ # "results_qwen2-5-72b-instruct.1s.jsonl",
30
+ # "results_qwen2-5-72b-instruct.zs.jsonl",
31
+ #
32
+ # "results_deepseek-r1-distill-14b.1s.jsonl",
33
+ # "results_deepseek-r1-distill-14b.zs.jsonl",
34
+ # "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
35
+ #
36
+ # "results_chatgpt-4o-mini.zs.jsonl",
37
+ # "results_chatgpt-o3-mini.zs.jsonl",
38
+ #
39
+ # "results_qwen2-5-7b-instruct_sp.1s.jsonl",
40
+ # "results_qwen2-5-7b-instruct_sp.zs.jsonl",
41
+
42
+ # "results_deepseek-r1-distill-8b.1s.jsonl",
43
+ "results_deepseek-r1-distill-8b.zs.jsonl",
44
+ ]
45
+
46
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
47
+ # !!! Must run reval_bracket_rerun.py first !!!
48
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
49
+
50
+
51
def revalidate_anagram_3(fp, reval_dir="revalidate_anagram_3", source_dir="prior_revalidate"):
    """Re-validate the level-3 records of ``GAME_NAME`` in one results file.

    Reads ``MODEL_OUTPUT_DIR/source_dir/fp`` (one JSON object per line),
    replays every record whose ``game`` is ``"<game_filename>_3"`` through a
    game instance loaded from the session's prompt, and writes every record
    to ``MODEL_OUTPUT_DIR/reval_dir/fp`` with ``solved`` recomputed.  Records
    of other games/levels are re-serialised without modification.  Turns that
    follow a turn re-validated as solved are dropped from the output.

    Reads the module-level names ``game_cls`` and ``sid_prompt_dict``, which
    the ``__main__`` section assigns before calling this function.

    Args:
        fp: Name of the results file inside both directories.
        reval_dir: Output sub-directory of ``MODEL_OUTPUT_DIR`` (created if missing).
        source_dir: Input sub-directory of ``MODEL_OUTPUT_DIR``.

    Returns:
        ``(count_pos, count_neg)``: how many records flipped from unsolved to
        solved, and from solved to unsolved.

    Raises:
        ValueError: If a later-turn record belongs to a different session
            than the most recent turn-1 record (or no turn-1 record was seen).
    """
    os.makedirs(MODEL_OUTPUT_DIR / reval_dir, exist_ok=True)
    target_game = f"{game_filename(GAME_NAME)}_3"
    count_pos, count_neg = 0, 0
    # Replay state of the session currently being processed.  Initialised up
    # front so malformed input fails with a clear error, not a NameError.
    cur_sid, cur_game, solved = None, None, False
    with (open(MODEL_OUTPUT_DIR / source_dir / fp, "r", encoding="utf8") as i,
          open(MODEL_OUTPUT_DIR / reval_dir / fp, "w", encoding="utf8") as o,
          tqdm(total=1000, desc=fp) as pbar,
          ):
        for line in i:
            res = json.loads(line)
            if res['game'] == target_game:
                if res['turn'] == 1:
                    # A first turn opens a new session: load its puzzle.
                    cur_sid = res["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dict[cur_sid])
                    pbar.update(1)
                elif solved:
                    # Session already solved on an earlier turn: drop the rest.
                    continue
                elif cur_sid != res["session"]:
                    raise ValueError(
                        f"{fp}: turn {res['turn']} of session {res['session']!r} "
                        f"does not follow a turn-1 record of that session "
                        f"(current session: {cur_sid!r})"
                    )
                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved
            o.write(json.dumps(res))
            o.write("\n")
    return count_pos, count_neg
80
+
81
+
82
if __name__ == "__main__":
    # Module-level state read by revalidate_anagram_3().
    problemset_path = PROBLEMSET_DIR / f"{game_filename(GAME_NAME)}_3.json"
    sid_prompt_dict = json.loads(problemset_path.read_text(encoding="utf8"))
    game_cls = _game_class_from_name(GAME_NAME)

    # Print the (count_pos, count_neg) pair for each selected results file.
    for fp in OUTPUT_FILENAMES:
        print(revalidate_anagram_3(fp))
reval_bracket_all.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from tqdm import tqdm
4
+ from textgames import GAME_NAMES, game_filename, _game_class_from_name
5
+ from pathlib import Path
6
+
7
+ GAME_NAME = GAME_NAMES[6]
8
+ PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
9
+ MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
10
+ OUTPUT_FILENAMES = [
11
+ # "results_gemma-2-9b-it.1s.jsonl",
12
+ # "results_gemma-2-9b-it.zs.jsonl",
13
+ # "results_gemma-2-27b-it.1s.jsonl",
14
+ # "results_gemma-2-27b-it.zs.jsonl",
15
+ #
16
+ # "results_llama-3.1-8b-instruct.1s.jsonl",
17
+ # "results_llama-3.1-8b-instruct.zs.jsonl",
18
+ # "results_llama-3.1-70b-instruct.1s.jsonl",
19
+ # "results_llama-3.1-70b-instruct.zs.jsonl",
20
+ # "results_llama-3.3-70b-instruct.1s.jsonl",
21
+ # "results_llama-3.3-70b-instruct.zs.jsonl",
22
+ #
23
+ # "results_qwen2-5-7b-instruct.1s.jsonl",
24
+ # "results_qwen2-5-7b-instruct.zs.jsonl",
25
+ # "results_qwen2-5-14b-instruct.1s.jsonl",
26
+ # "results_qwen2-5-14b-instruct.zs.jsonl",
27
+ # "results_qwen2-5-32b-instruct.1s.jsonl",
28
+ # "results_qwen2-5-32b-instruct.zs.jsonl",
29
+ # "results_qwen2-5-72b-instruct.1s.jsonl",
30
+ # "results_qwen2-5-72b-instruct.zs.jsonl",
31
+ #
32
+ # "results_deepseek-r1-distill-14b.1s.jsonl",
33
+ # "results_deepseek-r1-distill-14b.zs.jsonl",
34
+ # "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
35
+ #
36
+ # "results_chatgpt-4o-mini.1s.jsonl",
37
+ # "results_chatgpt-4o-mini.zs.jsonl",
38
+ # "results_chatgpt-o3-mini.zs.jsonl",
39
+ #
40
+ # "results_qwen2-5-7b-instruct_sp.1s.jsonl",
41
+ # "results_qwen2-5-7b-instruct_sp.zs.jsonl",
42
+
43
+ # "results_deepseek-r1-distill-8b.1s.jsonl",
44
+ "results_deepseek-r1-distill-8b.zs.jsonl",
45
+ ]
46
+
47
+
48
def revalidate_bracket(fp, reval_dir="revalidate_bracket_all",
                       source_dirs=("revalidate_bracket_rerun", "revalidate_anagram_3",)):
    """Re-validate every record of ``GAME_NAME`` (all levels) in one results file.

    The input is taken from the first directory in ``source_dirs`` (relative
    to ``MODEL_OUTPUT_DIR``) that contains ``fp``; if none does, ``fp`` is
    read from ``MODEL_OUTPUT_DIR`` itself.  Records whose ``game`` starts with
    the game's file name are replayed through a game instance loaded from the
    session's prompt and written with ``solved`` recomputed; all other records
    are re-serialised without modification.  Turns that follow a turn
    re-validated as solved are dropped from the output.

    Reads the module-level names ``game_cls`` and ``sid_prompt_dicts``, which
    the ``__main__`` section assigns before calling this function.

    Args:
        fp: Name of the results file.
        reval_dir: Output sub-directory of ``MODEL_OUTPUT_DIR`` (created if missing).
        source_dirs: Candidate input sub-directories, in priority order.

    Returns:
        ``(count_pos, count_neg)``: how many records flipped from unsolved to
        solved, and from solved to unsolved.

    Raises:
        ValueError: If a later-turn record belongs to a different session
            than the most recent turn-1 record (or no turn-1 record was seen).
    """
    os.makedirs(MODEL_OUTPUT_DIR / reval_dir, exist_ok=True)
    count_pos, count_neg = 0, 0
    # First candidate that actually holds the file; "." (the base directory)
    # only when none does.  A plain for/break would leave the last candidate
    # selected instead of the fallback.
    source_dir = next(
        (d for d in source_dirs if (MODEL_OUTPUT_DIR / d / fp).exists()),
        ".",
    )
    game_prefix = game_filename(GAME_NAME)
    # Replay state of the session currently being processed.  Initialised up
    # front so malformed input fails with a clear error, not a NameError.
    cur_sid, cur_game, solved = None, None, False
    with (open(MODEL_OUTPUT_DIR / source_dir / fp, "r", encoding="utf8") as i,
          open(MODEL_OUTPUT_DIR / reval_dir / fp, "w", encoding="utf8") as o,
          tqdm(total=3000, desc=fp) as pbar,
          ):
        for line in i:
            res = json.loads(line)
            if res['game'].startswith(game_prefix):
                # The level id is the suffix after the last underscore.
                sid_prompt_dict = sid_prompt_dicts[res['game'].rsplit("_", 1)[-1]]
                if res['turn'] == 1:
                    # A first turn opens a new session: load its puzzle.
                    cur_sid = res["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dict[cur_sid])
                    pbar.update(1)
                elif solved:
                    # Session already solved on an earlier turn: drop the rest.
                    continue
                elif cur_sid != res["session"]:
                    raise ValueError(
                        f"{fp}: turn {res['turn']} of session {res['session']!r} "
                        f"does not follow a turn-1 record of that session "
                        f"(current session: {cur_sid!r})"
                    )
                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved
            o.write(json.dumps(res))
            o.write("\n")
    return count_pos, count_neg
83
+
84
+
85
if __name__ == "__main__":
    def load(k):
        """Return the parsed problem-set JSON of level *k* for GAME_NAME."""
        path = PROBLEMSET_DIR / f"{game_filename(GAME_NAME)}_{k}.json"
        return json.loads(path.read_text(encoding="utf8"))

    # Module-level state read by revalidate_bracket(): one problem set per
    # level id ("1", "2", "3") and the game class to instantiate.
    sid_prompt_dicts = {str(n): load(str(n)) for n in (1, 2, 3)}
    game_cls = _game_class_from_name(GAME_NAME)

    # Print the (count_pos, count_neg) pair for each selected results file.
    for fp in OUTPUT_FILENAMES:
        print(revalidate_bracket(fp))
94
+
reval_bracket_rerun.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @title ##### Combine Rerun of the Bracket - All
2
+ import os
3
+ import json
4
+ from tqdm import tqdm
5
+ from pathlib import Path
6
+
7
# Directory layout (all below TG_MODEL_OUTPUT_DIR, default "model_outputs"):
#   fd_new - rerun results, one "<name>.6.jsonl" file per results file
#   fd_ori - existing results whose "Bracket Game" records get replaced
#   fd_out - combined output
MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
fd_new = MODEL_OUTPUT_DIR / "__runs__" / "_redo_bracket"
fd_ori = MODEL_OUTPUT_DIR / "revalidate_anagram_3"
fd_out = MODEL_OUTPUT_DIR / "revalidate_bracket_rerun"

# Every model is listed with its one-shot (1s) file first, then zero-shot (zs).
OUTPUT_FILENAMES = [
    f"results_{model}.{shot}.jsonl"
    for model in (
        "gemma-2-9b-it",
        "gemma-2-27b-it",

        "llama-3.1-8b-instruct",
        "llama-3.1-70b-instruct",
        "llama-3.3-70b-instruct",

        "qwen2-5-7b-instruct",
        "qwen2-5-14b-instruct",
        "qwen2-5-32b-instruct",
        "qwen2-5-72b-instruct",
    )
    for shot in ("1s", "zs")
]

os.makedirs(fd_out, exist_ok=True)
for fp in tqdm(OUTPUT_FILENAMES):
    with open(fd_out / fp, "w", encoding="utf8") as o:
        # 1) Copy every original record except the Bracket Game ones.
        with open(fd_ori / fp, "r", encoding="utf8") as i:
            o.writelines(
                line for line in i
                if not json.loads(line)['game'].startswith("Bracket Game")
            )
        # 2) Append the rerun records verbatim.
        with open((fd_new / fp).with_suffix(".6.jsonl"), "r", encoding="utf8") as i:
            o.writelines(i)
reval_crosswords_all.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from tqdm import tqdm
4
+ from textgames import GAME_NAMES, game_filename, _game_class_from_name
5
+ from pathlib import Path
6
+
7
+ GAME_NAME = GAME_NAMES[0]
8
+ PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
9
+ MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
10
+ OUTPUT_FILENAMES = [
11
+ # "results_gemma-2-9b-it.1s.jsonl",
12
+ # "results_gemma-2-9b-it.zs.jsonl",
13
+ # "results_gemma-2-27b-it.1s.jsonl",
14
+ # "results_gemma-2-27b-it.zs.jsonl",
15
+ #
16
+ # "results_llama-3.1-8b-instruct.1s.jsonl",
17
+ # "results_llama-3.1-8b-instruct.zs.jsonl",
18
+ # "results_llama-3.1-70b-instruct.1s.jsonl",
19
+ # "results_llama-3.1-70b-instruct.zs.jsonl",
20
+ # "results_llama-3.3-70b-instruct.1s.jsonl",
21
+ # "results_llama-3.3-70b-instruct.zs.jsonl",
22
+ #
23
+ # "results_qwen2-5-7b-instruct.1s.jsonl",
24
+ # "results_qwen2-5-7b-instruct.zs.jsonl",
25
+ # "results_qwen2-5-14b-instruct.1s.jsonl",
26
+ # "results_qwen2-5-14b-instruct.zs.jsonl",
27
+ # "results_qwen2-5-32b-instruct.1s.jsonl",
28
+ # "results_qwen2-5-32b-instruct.zs.jsonl",
29
+ # "results_qwen2-5-72b-instruct.1s.jsonl",
30
+ # "results_qwen2-5-72b-instruct.zs.jsonl",
31
+ #
32
+ # "results_deepseek-r1-distill-14b.1s.jsonl",
33
+ # "results_deepseek-r1-distill-14b.zs.jsonl",
34
+ # # "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
35
+ #
36
+ # "results_chatgpt-4o-mini.1s.jsonl",
37
+ # "results_chatgpt-4o-mini.zs.jsonl",
38
+ # "results_chatgpt-o3-mini.zs.jsonl",
39
+ #
40
+ # # "results_qwen2-5-7b-instruct_sp.1s.jsonl",
41
+ # # "results_qwen2-5-7b-instruct_sp.zs.jsonl",
42
+
43
+ "results_deepseek-r1-distill-8b.1s.jsonl",
44
+ "results_deepseek-r1-distill-8b.zs.jsonl",
45
+ ]
46
+
47
+
48
def revalidate_bracket(fp, reval_dir="revalidate_crosswords_all",
                       source_dirs=("revalidate_bracket_all",)):
    """Re-validate every record of ``GAME_NAME`` (all levels) in one results file.

    Despite its name, this function re-validates whichever game ``GAME_NAME``
    selects in this module.

    The input is taken from the first directory in ``source_dirs`` (relative
    to ``MODEL_OUTPUT_DIR``) that contains ``fp``; if none does, ``fp`` is
    read from ``MODEL_OUTPUT_DIR`` itself.  Records whose ``game`` starts with
    the game's file name are replayed through a game instance loaded from the
    session's prompt and written with ``solved`` recomputed; all other records
    are re-serialised without modification.  Turns that follow a turn
    re-validated as solved are dropped from the output.

    Reads the module-level names ``game_cls`` and ``sid_prompt_dicts``, which
    the ``__main__`` section assigns before calling this function.

    Args:
        fp: Name of the results file.
        reval_dir: Output sub-directory of ``MODEL_OUTPUT_DIR`` (created if missing).
        source_dirs: Candidate input sub-directories, in priority order.

    Returns:
        ``(count_pos, count_neg)``: how many records flipped from unsolved to
        solved, and from solved to unsolved.

    Raises:
        ValueError: If a later-turn record belongs to a different session
            than the most recent turn-1 record (or no turn-1 record was seen).
    """
    os.makedirs(MODEL_OUTPUT_DIR / reval_dir, exist_ok=True)
    count_pos, count_neg = 0, 0
    # First candidate that actually holds the file; "." (the base directory)
    # only when none does.  A plain for/break would leave the last candidate
    # selected instead of the fallback.
    source_dir = next(
        (d for d in source_dirs if (MODEL_OUTPUT_DIR / d / fp).exists()),
        ".",
    )
    game_prefix = game_filename(GAME_NAME)
    # Replay state of the session currently being processed.  Initialised up
    # front so malformed input fails with a clear error, not a NameError.
    cur_sid, cur_game, solved = None, None, False
    with (open(MODEL_OUTPUT_DIR / source_dir / fp, "r", encoding="utf8") as i,
          open(MODEL_OUTPUT_DIR / reval_dir / fp, "w", encoding="utf8") as o,
          tqdm(total=3000, desc=fp) as pbar,
          ):
        for line in i:
            res = json.loads(line)
            if res['game'].startswith(game_prefix):
                # The level id is the suffix after the last underscore.
                sid_prompt_dict = sid_prompt_dicts[res['game'].rsplit("_", 1)[-1]]
                if res['turn'] == 1:
                    # A first turn opens a new session: load its puzzle.
                    cur_sid = res["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dict[cur_sid])
                    pbar.update(1)
                elif solved:
                    # Session already solved on an earlier turn: drop the rest.
                    continue
                elif cur_sid != res["session"]:
                    raise ValueError(
                        f"{fp}: turn {res['turn']} of session {res['session']!r} "
                        f"does not follow a turn-1 record of that session "
                        f"(current session: {cur_sid!r})"
                    )
                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved
            o.write(json.dumps(res))
            o.write("\n")
    return count_pos, count_neg
83
+
84
+
85
if __name__ == "__main__":
    def load(k):
        """Return the parsed problem-set JSON of level *k* for GAME_NAME."""
        path = PROBLEMSET_DIR / f"{game_filename(GAME_NAME)}_{k}.json"
        return json.loads(path.read_text(encoding="utf8"))

    # Module-level state read by revalidate_bracket(): one problem set per
    # level id ("1", "2", "3") and the game class to instantiate.
    sid_prompt_dicts = {str(n): load(str(n)) for n in (1, 2, 3)}
    game_cls = _game_class_from_name(GAME_NAME)

    # Print the (count_pos, count_neg) pair for each selected results file.
    for fp in OUTPUT_FILENAMES:
        print(revalidate_bracket(fp))
94
+
reval_sudoku_all.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from tqdm import tqdm
4
+ from textgames import GAME_NAMES, game_filename, _game_class_from_name
5
+ from pathlib import Path
6
+
7
+ GAME_NAME = GAME_NAMES[1]
8
+ PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
9
+ MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
10
+ OUTPUT_FILENAMES = [
11
+ # "results_gemma-2-9b-it.1s.jsonl",
12
+ # "results_gemma-2-9b-it.zs.jsonl",
13
+ # "results_gemma-2-27b-it.1s.jsonl",
14
+ # "results_gemma-2-27b-it.zs.jsonl",
15
+ #
16
+ # "results_llama-3.1-8b-instruct.1s.jsonl",
17
+ # "results_llama-3.1-8b-instruct.zs.jsonl",
18
+ # "results_llama-3.1-70b-instruct.1s.jsonl",
19
+ # "results_llama-3.1-70b-instruct.zs.jsonl",
20
+ # "results_llama-3.3-70b-instruct.1s.jsonl",
21
+ # "results_llama-3.3-70b-instruct.zs.jsonl",
22
+ #
23
+ # "results_qwen2-5-7b-instruct.1s.jsonl",
24
+ # "results_qwen2-5-7b-instruct.zs.jsonl",
25
+ # "results_qwen2-5-14b-instruct.1s.jsonl",
26
+ # "results_qwen2-5-14b-instruct.zs.jsonl",
27
+ # "results_qwen2-5-32b-instruct.1s.jsonl",
28
+ # "results_qwen2-5-32b-instruct.zs.jsonl",
29
+ # "results_qwen2-5-72b-instruct.1s.jsonl",
30
+ # "results_qwen2-5-72b-instruct.zs.jsonl",
31
+ #
32
+ # "results_deepseek-r1-distill-14b.1s.jsonl",
33
+ # "results_deepseek-r1-distill-14b.zs.jsonl",
34
+ # # "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
35
+ #
36
+ # "results_chatgpt-4o-mini.1s.jsonl",
37
+ # "results_chatgpt-4o-mini.zs.jsonl",
38
+ # "results_chatgpt-o3-mini.zs.jsonl",
39
+ #
40
+ # # "results_qwen2-5-7b-instruct_sp.1s.jsonl",
41
+ # # "results_qwen2-5-7b-instruct_sp.zs.jsonl",
42
+
43
+ "results_deepseek-r1-distill-8b.1s.jsonl",
44
+ "results_deepseek-r1-distill-8b.zs.jsonl",
45
+ ]
46
+
47
+
48
def revalidate_bracket(fp, reval_dir="revalidate_sudoku_all",
                       source_dirs=("revalidate_crosswords_all",)):
    """Re-validate every record of ``GAME_NAME`` (all levels) in one results file.

    Despite its name, this function re-validates whichever game ``GAME_NAME``
    selects in this module.

    The input is taken from the first directory in ``source_dirs`` (relative
    to ``MODEL_OUTPUT_DIR``) that contains ``fp``; if none does, ``fp`` is
    read from ``MODEL_OUTPUT_DIR`` itself.  Records whose ``game`` starts with
    the game's file name are replayed through a game instance loaded from the
    session's prompt and written with ``solved`` recomputed; all other records
    are re-serialised without modification.  Turns that follow a turn
    re-validated as solved are dropped from the output.

    Reads the module-level names ``game_cls`` and ``sid_prompt_dicts``, which
    the ``__main__`` section assigns before calling this function.

    Args:
        fp: Name of the results file.
        reval_dir: Output sub-directory of ``MODEL_OUTPUT_DIR`` (created if missing).
        source_dirs: Candidate input sub-directories, in priority order.

    Returns:
        ``(count_pos, count_neg)``: how many records flipped from unsolved to
        solved, and from solved to unsolved.

    Raises:
        ValueError: If a later-turn record belongs to a different session
            than the most recent turn-1 record (or no turn-1 record was seen).
    """
    os.makedirs(MODEL_OUTPUT_DIR / reval_dir, exist_ok=True)
    count_pos, count_neg = 0, 0
    # First candidate that actually holds the file; "." (the base directory)
    # only when none does.  A plain for/break would leave the last candidate
    # selected instead of the fallback.
    source_dir = next(
        (d for d in source_dirs if (MODEL_OUTPUT_DIR / d / fp).exists()),
        ".",
    )
    game_prefix = game_filename(GAME_NAME)
    # Replay state of the session currently being processed.  Initialised up
    # front so malformed input fails with a clear error, not a NameError.
    cur_sid, cur_game, solved = None, None, False
    with (open(MODEL_OUTPUT_DIR / source_dir / fp, "r", encoding="utf8") as i,
          open(MODEL_OUTPUT_DIR / reval_dir / fp, "w", encoding="utf8") as o,
          tqdm(total=3000, desc=fp) as pbar,
          ):
        for line in i:
            res = json.loads(line)
            if res['game'].startswith(game_prefix):
                # The level id is the suffix after the last underscore.
                sid_prompt_dict = sid_prompt_dicts[res['game'].rsplit("_", 1)[-1]]
                if res['turn'] == 1:
                    # A first turn opens a new session: load its puzzle.
                    cur_sid = res["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dict[cur_sid])
                    pbar.update(1)
                elif solved:
                    # Session already solved on an earlier turn: drop the rest.
                    continue
                elif cur_sid != res["session"]:
                    raise ValueError(
                        f"{fp}: turn {res['turn']} of session {res['session']!r} "
                        f"does not follow a turn-1 record of that session "
                        f"(current session: {cur_sid!r})"
                    )
                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved
            o.write(json.dumps(res))
            o.write("\n")
    return count_pos, count_neg
83
+
84
+
85
if __name__ == "__main__":
    def load(k):
        """Return the parsed problem-set JSON of level *k* for GAME_NAME."""
        path = PROBLEMSET_DIR / f"{game_filename(GAME_NAME)}_{k}.json"
        return json.loads(path.read_text(encoding="utf8"))

    # Module-level state read by revalidate_bracket(): one problem set per
    # level id ("1", "2", "3") and the game class to instantiate.
    sid_prompt_dicts = {str(n): load(str(n)) for n in (1, 2, 3)}
    game_cls = _game_class_from_name(GAME_NAME)

    # Print the (count_pos, count_neg) pair for each selected results file.
    for fp in OUTPUT_FILENAMES:
        print(revalidate_bracket(fp))
94
+
textgames-scrabble-black2-ss.png CHANGED

Git LFS Details

  • SHA256: b016722b2f99d17b8a0f164ab82eb528dbd8832997f48e66a90a2c6cd430a754
  • Pointer size: 131 Bytes
  • Size of remote file: 235 kB
textgames/__init__.py CHANGED
@@ -14,8 +14,10 @@ from pandas import read_csv
14
  import json
15
 
16
 
17
- # ["🔑\tPassword Game", "🧩\tText Sudoku", "🗳️\tBracket Game", "📈\tOrdering Text",
18
- # "🏝️\tIslands", "🔎\tString Search", "📰\tCrossword Arranger", "🔤\tAnagram Scribble",]
 
 
19
  THE_GAMES = {
20
  k: v.get_game_name() for k, v in [
21
  ("1", CrosswordArrangerGame),
@@ -60,12 +62,13 @@ def _game_class_from_name(game_name):
60
  return None
61
 
62
 
63
- def preload_game(game_name, level_id, user):
64
  game_cls = _game_class_from_name(game_name)
65
- email_sid_dict = read_csv(
66
- f"{os.getenv('TEXTGAMES_OUTPUT_DIR')}/textgames_userauth.tsv", sep='\t'
67
- ).dropna().set_index("EMAIL").SID.to_dict()
68
- sid = email_sid_dict.get(user["email"])
 
69
  print(f"preload_game('{game_name}', '{level_id}', '{user['email']}') on {sid}")
70
 
71
  with open(f"problemsets/{game_filename(game_name)}_{level_id}.json", "r", encoding="utf8") as f:
 
14
  import json
15
 
16
 
17
+ # [
18
+ # "📰\tCrossword Arranger", "🧩\tText Sudoku", "🏝️\tIslands", "🔑\tPassword Game",
19
+ # "📈\tOrdering Text", "🔤\tAnagram Scribble", "🗳️\tBracket Game", "🔎\tString Search",
20
+ # ]
21
  THE_GAMES = {
22
  k: v.get_game_name() for k, v in [
23
  ("1", CrosswordArrangerGame),
 
62
  return None
63
 
64
 
65
+ def preload_game(game_name, level_id, user, sid=None):
66
  game_cls = _game_class_from_name(game_name)
67
+ if not sid:
68
+ email_sid_dict = read_csv(
69
+ f"{os.getenv('TEXTGAMES_OUTPUT_DIR')}/textgames_userauth.tsv", sep='\t'
70
+ ).dropna().set_index("EMAIL").SID.to_dict()
71
+ sid = email_sid_dict.get(user["email"])
72
  print(f"preload_game('{game_name}', '{level_id}', '{user['email']}') on {sid}")
73
 
74
  with open(f"problemsets/{game_filename(game_name)}_{level_id}.json", "r", encoding="utf8") as f:
textgames/anagram_scribble/anagram_scribble.py CHANGED
@@ -5,6 +5,7 @@ import json
5
  import string
6
  import re
7
 
 
8
  class AnagramScribble(BaseGame):
9
  @staticmethod
10
  def get_game_name() -> str:
@@ -43,6 +44,18 @@ class AnagramScribble(BaseGame):
43
  if total_chars_extraction != "Error loading game state.":
44
  characters = total_chars_extraction.split(",")
45
  self.total_chars = [char.strip().strip("'") for char in characters]
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def _generate_new_game(self, *args, **kwargs) -> None:
48
  self.low_num_chars = kwargs['low_num_chars']
@@ -57,16 +70,16 @@ class AnagramScribble(BaseGame):
57
 
58
  def _get_prompt(self) -> str:
59
  if self.allow_repeat:
60
- prompt = f"Construct a valid {self.num_chars}-character English word from the following letters:\n{self.total_chars}.\nEach character can be used multiple times. Please write None if there is no valid combination."
61
  else:
62
- prompt = f"Construct a valid {self.num_chars}-character English word from the following letters:\n{self.total_chars}.\nEach character can only be used once. Please write None if there is no valid combination."
63
  return prompt
64
 
65
  def _validate(self, answer: str) -> (bool, str):
66
- answer = answer.lower()
67
- if self.possible_ans != "" and answer == "none":
68
  val_msg = "There is a valid answer."
69
  return False, val_msg
 
70
  if len(answer) != self.num_chars:
71
  val_msg = f"Your answer must be exactly {self.num_chars} characters long"
72
  return False, val_msg
@@ -74,12 +87,31 @@ class AnagramScribble(BaseGame):
74
  if char not in self.total_chars:
75
  val_msg = "Your answer must only contain the characters provided"
76
  return False, val_msg
77
- if (not self.allow_repeat and (len(set(answer)) != len(answer))
78
- and (len(self.possible_ans) == len(set(self.possible_ans)))):
79
- val_msg = "Your answer must not contain repeated characters"
80
- return False, val_msg
 
 
 
 
 
 
 
 
 
81
  if answer not in self.WORD_LIST_BIN[str(self.num_chars)]:
82
  val_msg = "Your answer is not a valid English word"
83
  return False, val_msg
84
 
85
  return True, ""
 
 
 
 
 
 
 
 
 
 
 
5
  import string
6
  import re
7
 
8
+
9
  class AnagramScribble(BaseGame):
10
  @staticmethod
11
  def get_game_name() -> str:
 
44
  if total_chars_extraction != "Error loading game state.":
45
  characters = total_chars_extraction.split(",")
46
  self.total_chars = [char.strip().strip("'") for char in characters]
47
+ self.possible_ans = ""
48
+ _chars = sorted(self.total_chars)
49
+ for w in self.WORD_LIST_BIN[str(self.num_chars)]:
50
+ _ans = sorted(w)
51
+ j, k = 0, 0
52
+ while j < len(_ans) and k < len(_chars):
53
+ if _ans[j] == _chars[k]:
54
+ j += 1
55
+ k += 1
56
+ if j >= len(_ans):
57
+ self.possible_ans = w
58
+ break
59
 
60
  def _generate_new_game(self, *args, **kwargs) -> None:
61
  self.low_num_chars = kwargs['low_num_chars']
 
70
 
71
def _get_prompt(self) -> str:
    """Build the task prompt from the word length, letters and repeat rule."""
    # Only this one sentence differs between the two game variants.
    if self.allow_repeat:
        usage_rule = "Each character can be used multiple times."
    else:
        usage_rule = "Each character can only be used once."
    return (
        f"Construct a valid {self.num_chars}-character English word from the following letters:\n"
        f"{self.total_chars}.\n"
        f"{usage_rule} Please write None if there is no valid combination. Print only the answer.\n"
    )
77
 
78
  def _validate(self, answer: str) -> (bool, str):
79
+ if self.possible_ans != "" and answer == "None":
 
80
  val_msg = "There is a valid answer."
81
  return False, val_msg
82
+ answer = answer.lower()
83
  if len(answer) != self.num_chars:
84
  val_msg = f"Your answer must be exactly {self.num_chars} characters long"
85
  return False, val_msg
 
87
  if char not in self.total_chars:
88
  val_msg = "Your answer must only contain the characters provided"
89
  return False, val_msg
90
+ # if (not self.allow_repeat and (len(set(answer)) != len(answer))
91
+ # and (len(self.possible_ans) == len(set(self.possible_ans)))):
92
+ if not self.allow_repeat:
93
+ _ans = sorted(answer)
94
+ _chars = sorted(self.total_chars)
95
+ j, k = 0, 0
96
+ while j < len(_ans) and k < len(_chars):
97
+ if _ans[j] == _chars[k]:
98
+ j += 1
99
+ k += 1
100
+ if j < len(_ans):
101
+ val_msg = "Your answer must not contain repeated characters"
102
+ return False, val_msg
103
  if answer not in self.WORD_LIST_BIN[str(self.num_chars)]:
104
  val_msg = "Your answer is not a valid English word"
105
  return False, val_msg
106
 
107
  return True, ""
108
+
109
@staticmethod
def example() -> (str, str):
    """Return a fixed ``(prompt, answer)`` demonstration pair for this game."""
    prompt_lines = (
        "Construct a valid 5-character English word from the following letters:",
        "['e', 'l', 'o', 'b', 's', 'p'].",
        "Each character can be used multiple times. Please write None if there is no valid combination."
        " Print only the answer.",
    )
    # Every prompt line, including the last, is newline-terminated.
    return "\n".join(prompt_lines) + "\n", "sleep"
117
+
textgames/bracket_game/bracket_game.py CHANGED
@@ -1,5 +1,6 @@
1
  import random
2
  import re
 
3
  from pathlib import Path
4
  from textgames.base_game import BaseGame
5
  #%%
@@ -57,48 +58,94 @@ class BracketGame(BaseGame):
57
  self.MULTI_WORD_LIST.append(self.WORD_LIST[num1] + self.WORD_LIST[num2])
58
 
59
  def _validate(self, answer: str) -> (bool, str):
60
- for rule in self.rules:
61
- arr = answer.split(rule[0])
62
-
63
- if rule[1][1] not in arr[0] or rule[1][2] not in arr[1]:
64
- val_msg = f"{rule[0]} is not between the correct bracket, {rule[1][1]} not in {arr[0]} and {rule[1][2]} not in {arr[1]}"
65
- return False, val_msg
66
-
67
- filter_answer = answer
68
- for i in range(0, 26):
69
- cc = chr(ord("a") + i)
70
- filter_answer = filter_answer.replace(cc,"")
71
 
72
- cc = chr(ord("A") + i)
73
- filter_answer = filter_answer.replace(cc,"")
74
-
75
- open_bracket_list = ["[", "{", "(", "<"]
76
- close_bracket_map = {
77
- "[":"]", "{":"}", "(":")", "<":">"
78
- }
79
-
80
- # check max depth
81
- count = 0
82
- st = []
83
-
84
- for i in range(len(filter_answer)):
85
- if (filter_answer[i] in open_bracket_list):
86
- st.append(filter_answer[i]) # pushing the bracket in the stack
87
- else:
88
- if len(st) > 0 and (filter_answer[i] == close_bracket_map[st[-1]]):
89
- if (count < len(st)):
90
- count = len(st)
91
- st.pop()
92
- else:
 
 
 
93
  val_msg = "There is a closing bracket without an open bracket"
94
  return False, val_msg
95
-
96
- if count == self.depth:
97
- return True, ""
98
- else:
99
- val_msg = f"The depth of the bracket is {count}. The expected depth is {self.depth}"
100
  return False, val_msg
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def _generate_new_game(self, *args, **kwargs) -> None:
103
  num_words = kwargs["num_words"]
104
  num_rules = kwargs["num_rules"]
@@ -141,6 +188,7 @@ class BracketGame(BaseGame):
141
  prompt = f"You are given a text {self.string} Your job is to put some valid parenthesis brackets in the text such that:\n"
142
  for rule in self.rules:
143
  prompt += f"- \"{rule[0]}\" is inside a {rule[1][0]} bracket\n"
 
144
  prompt += f"The bracket depth must be {self.depth} and print only the answer\n"
145
  return prompt
146
 
@@ -159,7 +207,7 @@ class BracketGame(BaseGame):
159
  else:
160
  return 0
161
 
162
- content = state_string.split("the text such that:")[1].split("\nThe bracket depth must be")[0].split("\n")
163
 
164
  self.words = []
165
  self.rules = []
@@ -188,3 +236,14 @@ class BracketGame(BaseGame):
188
  self.create_multiple_words()
189
 
190
  sort_game_states(self)
 
 
 
 
 
 
 
 
 
 
 
 
1
  import random
2
  import re
3
+ from bisect import bisect_left
4
  from pathlib import Path
5
  from textgames.base_game import BaseGame
6
  #%%
 
58
  self.MULTI_WORD_LIST.append(self.WORD_LIST[num1] + self.WORD_LIST[num2])
59
 
60
def _validate(self, answer: str) -> (bool, str):
    """Check a bracketed answer against the base text, the depth and the rules.

    All whitespace is removed from ``answer`` and it is lower-cased before
    checking.  The checks run in this order:

    1. the letters of the answer equal the letters of ``self.string``;
    2. the brackets are well formed: every closing bracket matches the most
       recently opened, still-open bracket, and nothing is left open;
    3. the maximum nesting depth (over all bracket types) equals ``self.depth``;
    4. for every rule, the first occurrence of the rule's text starts inside
       at least one open bracket of the rule's type.

    Returns:
        ``(True, "")`` when the answer is valid, else ``(False, reason)``.
    """
    answer = "".join(answer.split()).lower()

    def letters_of(text):
        return "".join(ch for ch in text if ch.isalpha())

    if letters_of(answer) != letters_of(self.string.lower()):
        val_msg = f"You are not allowed to change the character sequence of base text '{self.string}'."
        return False, val_msg

    # Entries of self.BRACKETS are indexed as (type name, open char, close char).
    char2type_op = {b[1]: b[0] for b in self.BRACKETS}
    char2type_ed = {b[2]: b[0] for b in self.BRACKETS}

    # Per bracket type: (position, depth of that type right after the bracket
    # at that position), in increasing position.  (-1, 0) is the sentinel for
    # "before the text, nothing open".
    depth_count = {b[0]: [(-1, 0)] for b in self.BRACKETS}

    open_stack = []  # types of the currently open brackets, innermost last
    mak = 0          # maximum total nesting depth seen so far
    for i, c in enumerate(answer):
        if c in char2type_op:
            b_type, delta = char2type_op[c], 1
            open_stack.append(b_type)
            mak = max(mak, len(open_stack))
        elif c in char2type_ed:
            b_type, delta = char2type_ed[c], -1
            # A closer must match the innermost open bracket; per-type
            # counters alone would accept interleavings such as "([)]".
            if not open_stack or open_stack[-1] != b_type:
                val_msg = "There is a closing bracket without an open bracket"
                return False, val_msg
            open_stack.pop()
        else:
            continue
        history = depth_count[b_type]
        history.append((i, history[-1][-1] + delta))

    if open_stack:
        val_msg = "There is an open bracket without a closing bracket"
        return False, val_msg

    if mak != self.depth:
        val_msg = f"The depth of the bracket is {mak}. The expected depth is {self.depth}"
        return False, val_msg

    for rule in self.rules:
        pos = answer.find(rule[0])
        if pos < 0:
            val_msg = f"The text '{rule[0]}' is not found in your answer."
            return False, val_msg

        # Last bracket event of the rule's type strictly before the text.
        history = depth_count[rule[1][0]]
        i_depth = bisect_left(history, (pos, -1)) - 1
        if history[i_depth][-1] < 1:
            val_msg = f"The text '{rule[0]}' is not inside any {rule[1][0]} bracket {rule[1][1]} {rule[1][2]}"
            return False, val_msg

    return True, ""
113
+
114
+ # filter_answer = answer
115
+ # for i in range(0, 26):
116
+ # cc = chr(ord("a") + i)
117
+ # filter_answer = filter_answer.replace(cc,"")
118
+ #
119
+ # cc = chr(ord("A") + i)
120
+ # filter_answer = filter_answer.replace(cc,"")
121
+ #
122
+ # open_bracket_list = ["[", "{", "(", "<"]
123
+ # close_bracket_map = {
124
+ # "[":"]", "{":"}", "(":")", "<":">"
125
+ # }
126
+ #
127
+ # # check max depth
128
+ # count = 0
129
+ # st = []
130
+ #
131
+ # for i in range(len(filter_answer)):
132
+ # if (filter_answer[i] in open_bracket_list):
133
+ # st.append(filter_answer[i]) # pushing the bracket in the stack
134
+ # else:
135
+ # if len(st) > 0 and (filter_answer[i] == close_bracket_map[st[-1]]):
136
+ # if (count < len(st)):
137
+ # count = len(st)
138
+ # st.pop()
139
+ # else:
140
+ # val_msg = "There is a closing bracket without an open bracket"
141
+ # return False, val_msg
142
+ #
143
+ # if count == self.depth:
144
+ # return True, ""
145
+ # else:
146
+ # val_msg = f"The depth of the bracket is {count}. The expected depth is {self.depth}"
147
+ # return False, val_msg
148
+
149
  def _generate_new_game(self, *args, **kwargs) -> None:
150
  num_words = kwargs["num_words"]
151
  num_rules = kwargs["num_rules"]
 
188
  prompt = f"You are given a text {self.string} Your job is to put some valid parenthesis brackets in the text such that:\n"
189
  for rule in self.rules:
190
  prompt += f"- \"{rule[0]}\" is inside a {rule[1][0]} bracket\n"
191
+ prompt += "The open and close parenthesis for block is [ ], curly is { }, round is ( ), and angle is < >\n"
192
  prompt += f"The bracket depth must be {self.depth} and print only the answer\n"
193
  return prompt
194
 
 
207
  else:
208
  return 0
209
 
210
+ content = state_string.split("the text such that:")[1].split("\nThe open and close parenthesis ")[0].split("\n")
211
 
212
  self.words = []
213
  self.rules = []
 
236
  self.create_multiple_words()
237
 
238
  sort_game_states(self)
239
+
240
@staticmethod
def example() -> (str, str):
    """Return a fixed ``(prompt, answer)`` demonstration pair for this game."""
    prompt_lines = (
        "You are given a text fabuloustextgames Your job is to put some valid parenthesis brackets in the text such that:",
        "- \"games\" is inside a round bracket",
        "- \"text\" is inside a angle bracket",
        "- \"fabulous\" is inside a block bracket",
        "The open and close parenthesis for block is [ ], curly is { }, round is ( ), and angle is < >",
        "The bracket depth must be 2 and print only the answer",
    )
    # Every prompt line, including the last, is newline-terminated.
    return "\n".join(prompt_lines) + "\n", "[[fabulous]<text>(games)]"
textgames/crossword_arranger/crossword_arranger.py CHANGED
@@ -125,19 +125,47 @@ class CrosswordArrangerGame(BaseGame):
125
  return prompt
126
 
127
  def _validate(self, answer: str) -> (bool, str):
128
- ans_hor = list(filter(None, answer.lower().replace(' ', '\n').split("\n")))
 
 
129
  val_msg = ""
 
 
 
 
130
  if len(ans_hor) != self.board_size:
131
  val_msg = f"Mismatch answer length found!! Expected size of {self.board_size}, got {len(ans_hor)}."
132
  return False, val_msg
 
 
 
 
133
  ans_ver = [''.join(ans_hor[r][c] for r in range(self.board_size)) for c in range(self.board_size)]
134
  word_set = set(self.word_list)
135
- for w in chain(ans_hor, ans_ver):
136
  if w not in word_set:
 
 
137
  return False, val_msg
138
  word_set.remove(w)
139
  return True, val_msg
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  #%%
143
 
 
125
  return prompt
126
 
127
def _validate(self, answer: str) -> (bool, str):
    """Check that ``answer`` is a square board built from distinct listed words.

    The answer is lower-cased and split on whitespace into rows.  If that
    does not yield ``board_size`` rows but yields exactly ``board_size**2``
    single characters, the characters are regrouped into rows.  Every row and
    every column must then be a word of ``self.word_list``, each used once.

    Returns:
        ``(True, "")`` when valid, else ``(False, reason)``.
    """
    size = self.board_size
    tokens = (answer or "").lower().split()

    rows = tokens
    # Accept a board written one letter per token ("a p p l e e ...").
    if len(tokens) != size and len(tokens) == size * size and all(len(t) == 1 for t in tokens):
        rows = ["".join(tokens[k:k + size]) for k in range(0, len(tokens), size)]

    if len(rows) != size:
        return False, f"Mismatch answer length found!! Expected size of {size}, got {len(rows)}."

    bad_row = next((r for r in rows if len(r) != size), None)
    if bad_row is not None:
        return False, f"Mismatch answer length found!! Expected size of {size}, got {len(bad_row)}."

    # Rows are all `size` long here, so zip() yields exactly `size` columns.
    columns = ["".join(col) for col in zip(*rows)]

    unused = set(self.word_list)
    for idx, word in enumerate(chain(rows, columns)):
        if word not in unused:
            direction = 'Horizontal' if idx < size else 'Vertical'
            return False, (f"Mismatch answer word found!! {direction} word"
                           f" '{word}' is not in the word set.")
        unused.remove(word)  # each listed word may be used only once
    return True, ""
152
 
153
@staticmethod
def example() -> (str, str):
    """Return a fixed ``(prompt, answer)`` demonstration pair for this game."""
    words = ("app", "all", "and", "lee", "let", "pat", "pee", "pet")
    bullet_list = "".join(f"- {word}\n" for word in words)
    prompt = (
        "Given a board size of 3x3, arrange a possible crossword puzzle answer from a list of words.\n"
        "Item in the list can only be used once.\n\n"
        "List of words:\n"
        + bullet_list +
        "\nPrint only the answer."
    )
    # Rows app/lee/let; the columns read all/pee/pet.
    answer = "\n".join(("app", "lee", "let"))
    return prompt, answer
169
 
170
  #%%
171
 
textgames/islands/islands.py CHANGED
@@ -99,8 +99,8 @@ class Islands(BaseGame):
99
  answer = [a.replace(" ", "").lower().strip() for a in answer]
100
 
101
  # check the size
102
- if len(answer) != self.N or len(answer[0]) != self.N:
103
- val_msg = f"2D grid is not {self.N} x {self.N}. ({len(answer)} x {len(answer[0])})"
104
  return False, val_msg
105
 
106
  # check the tiles, ensure they are valid
@@ -194,4 +194,16 @@ Your 2D grid must follow the following rules:
194
 
195
  Print only the answer.
196
  """
197
- return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  answer = [a.replace(" ", "").lower().strip() for a in answer]
100
 
101
  # check the size
102
+ if len(answer) != self.N or any((len(a) < self.N) for a in answer):
103
+ val_msg = f"2D grid is not {self.N} x {self.N}. ({len(answer)} x {set(len(a) for a in answer)})"
104
  return False, val_msg
105
 
106
  # check the tiles, ensure they are valid
 
194
 
195
  Print only the answer.
196
  """
197
+ return prompt
198
+
199
+ @staticmethod
200
+ def example() -> (str, str):
201
+ prompt = ("You are asked to construct a 2D 5 x 5 grid, consisting of water tiles (denoted by \u2019.\u2019), \n"
202
+ "land tiles (denoted by \u2019#\u2019). \n\n"
203
+ "A group of connected land tiles in 4 cardinal directions forms an island.\n\n"
204
+ "Your 2D grid must follow the following rules:\n"
205
+ "- There must be exactly 1 islands.\n"
206
+ "- The size of each island must be from 1 to 2 tiles.\n\n"
207
+ "Print only the answer.\n")
208
+ answer = "...##\n.....\n.....\n.....\n....."
209
+ return prompt, answer
textgames/ordering_text/ordering_text.py CHANGED
@@ -5,10 +5,10 @@ Rules Description
5
 
6
  word length:
7
  - example: word less than 5 characters gets 10 points
8
- - possible operands: {\eq, \lt, \gt, \ne}
9
- - \le and \ge will be randomized for prompt generation
10
- - possible combinations: {\gt\lt, \gt\lt\ne}
11
- - only 1 \ne is considered
12
 
13
  neighboring / consecutive chars
14
  - example: every pair of consecutive consonant gets 5 points
@@ -66,6 +66,15 @@ from textgames.base_game import BaseGame
66
  from textgames.assets.word_list import WORDS_LIST, WORDS_BY_LEN
67
 
68
 
 
 
 
 
 
 
 
 
 
69
  #%%
70
  class Scoring:
71
  def __init__(self, point: int):
@@ -505,14 +514,15 @@ class OrderingTextGame(BaseGame):
505
  return self.answer # sorted(self.words, key=lambda word: (self.get_point(word), word))
506
 
507
  def _validate(self, answer: str) -> (bool, str):
508
- answer = answer.lower().replace(' ', '\n')
509
- if answer != "\n".join(self.get_answer()):
510
- for i, (a, b) in enumerate(zip(answer.split(), self.get_answer()), 1):
511
- if a != b:
512
- val_msg = f"{a} is not supposed to be at index {i}."
513
- return False, val_msg
514
- else:
515
- return True, ""
 
516
 
517
  def _generate_new_game(self, *args, **kwargs) -> None:
518
  if "preset_config" in kwargs:
@@ -588,6 +598,26 @@ class OrderingTextGame(BaseGame):
588
  prompt += "\nPrint only the answer."
589
  return prompt
590
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
591
 
592
  #%%
593
 
 
5
 
6
  word length:
7
  - example: word less than 5 characters gets 10 points
8
+ - possible operands: {\\eq, \\lt, \\gt, \\ne}
9
+ - \\le and \\ge will be randomized for prompt generation
10
+ - possible combinations: {\\gt\\lt, \\gt\\lt\\ne}
11
+ - only 1 \\ne is considered
12
 
13
  neighboring / consecutive chars
14
  - example: every pair of consecutive consonant gets 5 points
 
66
  from textgames.assets.word_list import WORDS_LIST, WORDS_BY_LEN
67
 
68
 
69
+ #%%
70
# English ordinal names keyed by 1-based position (1..18).
index_to_word = dict(enumerate((
    "first", "second", "third", "fourth", "fifth",
    "sixth", "seventh", "eighth", "ninth", "tenth",
    "eleventh", "twelfth", "thirteenth", "fourteenth",
    "fifteenth", "sixteenth", "seventeenth", "eighteenth",
), start=1))
76
+
77
+
78
  #%%
79
  class Scoring:
80
  def __init__(self, point: int):
 
514
  return self.answer # sorted(self.words, key=lambda word: (self.get_point(word), word))
515
 
516
  def _validate(self, answer: str) -> (bool, str):
517
+ answer = answer.lower().replace(',', ' ').split()
518
+ gold = self.get_answer()
519
+ if len(answer) < len(gold):
520
+ return False, f"Your answer is too short. There should be {len(gold)} items."
521
+ for i, (a, b) in enumerate(zip(answer, self.get_answer()), 1):
522
+ if a != b:
523
+ val_msg = f"'{a}' is not supposed to be the {index_to_word[i]} word in the order."
524
+ return False, val_msg
525
+ return True, ""
526
 
527
  def _generate_new_game(self, *args, **kwargs) -> None:
528
  if "preset_config" in kwargs:
 
598
  prompt += "\nPrint only the answer."
599
  return prompt
600
 
601
+ @staticmethod
602
+ def example() -> (str, str):
603
+ prompt = ("Given a set of rules to calculate point, sort the set of words in decreasing order.\n"
604
+ "When there 2 or more words with same point, sort lexicographically.\n\n"
605
+ "Rules:\n"
606
+ "- add 10 points if there exists 'u' in the word\n\n"
607
+ "Words:\n"
608
+ "- hudi\n"
609
+ "- genta\n"
610
+ "- aji\n"
611
+ "- ruochen\n\n"
612
+ "Print only the answer.")
613
+ answer = (
614
+ "hudi\n"
615
+ "ruochen\n"
616
+ "aji\n"
617
+ "genta"
618
+ )
619
+ return prompt, answer
620
+
621
 
622
  #%%
623
 
textgames/password_game/password_game.py CHANGED
@@ -274,3 +274,13 @@ class PasswordGame(BaseGame):
274
  self.rules = [rule for rule in new_rules]
275
 
276
  sort_game_states(self)
 
 
 
 
 
 
 
 
 
 
 
274
  self.rules = [rule for rule in new_rules]
275
 
276
  sort_game_states(self)
277
+
278
+ @staticmethod
279
+ def example() -> (str, str):
280
+ prompt = ("Please write a text string without any space by following a set of given rules."
281
+ " Please write only the answer and follow the following criteria:\n"
282
+ "- the text has 6 english character\n"
283
+ "- the text has 0 uppercase characters\n")
284
+ answer = "hoodie"
285
+ return prompt, answer
286
+
textgames/string_search/string_search.py CHANGED
@@ -309,4 +309,16 @@ Find a substring of exactly {self.answer_len} characters long that:
309
  {extra_constraints}
310
  Print only the answer.
311
  """
312
- return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  {extra_constraints}
310
  Print only the answer.
311
  """
312
+ return prompt
313
+
314
+ @staticmethod
315
+ def example() -> (str, str):
316
+ prompt = ("You are given the following string:\n"
317
+ "hudigentaajiruochen\n\n"
318
+ "Find a substring of exactly 3 characters long that:\n"
319
+ " - Contains t\n"
320
+ " - Does not contain i and a\n\n"
321
+ "Print only the answer.\n")
322
+ answer = "ent"
323
+ return prompt, answer
324
+
textgames/sudoku/sudoku.py CHANGED
@@ -9,6 +9,7 @@ Please solve the 9x9 sudoku with 1,2,3,4,5,6,7,8,9 as the values and fill _ with
9
  Print only the answer.
10
  """
11
 
 
12
  #%%
13
  class Sudoku(BaseGame):
14
  @staticmethod
@@ -28,34 +29,47 @@ class Sudoku(BaseGame):
28
  for j in range(self.size):
29
  num = mat[i][j]
30
  if num == self.empty_character:
31
- return False
32
 
33
  subgrid_index = (i // self.srn) * self.srn + (j // self.srn)
34
 
35
- if num in rows[i] or num in cols[j] or num in subgrids[subgrid_index]:
36
- return False
37
-
 
 
 
 
38
  rows[i].add(num)
39
  cols[j].add(num)
40
  subgrids[subgrid_index].add(num)
41
 
42
- return True
43
 
44
  def _validate(self, input) -> (bool, str):
45
  mat = [[self.empty_character for i in range(self.size)] for j in range(self.size)]
46
 
 
47
  arr = input.split()
 
 
 
 
 
 
 
48
  for i in range(len(arr)):
49
  for j in range(len(arr[i])):
50
  if arr[i][j] not in self.char_to_id:
51
- val_msg = "Found unrecognized character(s)"
52
  return False, val_msg
53
 
54
  mat[i][j] = self.char_to_id[arr[i][j]]
55
  if arr[i][j] != self.mat[i][j] and self.mat[i][j] != self.empty_character:
56
  val_msg = "One or more characters are replaced"
57
  return False, val_msg
58
- return self.is_valid_sudoku(mat), ""
 
59
 
60
  def _generate_new_game(self, *args, **kwargs) -> None:
61
  size=kwargs["size"]
@@ -228,4 +242,14 @@ class Sudoku(BaseGame):
228
  self.char_to_id = {}
229
  for c_id in range(len(self.characters)):
230
  self.char_to_id[self.characters[c_id]] = c_id
231
-
 
 
 
 
 
 
 
 
 
 
 
9
  Print only the answer.
10
  """
11
 
12
+
13
  #%%
14
  class Sudoku(BaseGame):
15
  @staticmethod
 
29
  for j in range(self.size):
30
  num = mat[i][j]
31
  if num == self.empty_character:
32
+ return False, "There are unfilled cells"
33
 
34
  subgrid_index = (i // self.srn) * self.srn + (j // self.srn)
35
 
36
+ if num in rows[i]:
37
+ return False, f"Duplicated row value ({num}) for cell in row {i+1} column {j+1}."
38
+ elif num in cols[j]:
39
+ return False, f"Duplicated column value ({num}) for cell in row {i+1} column {j+1}."
40
+ elif num in subgrids[subgrid_index]:
41
+ return False, f"Duplicated subgrid value ({num}) for cell in row {i+1} column {j+1}."
42
+
43
  rows[i].add(num)
44
  cols[j].add(num)
45
  subgrids[subgrid_index].add(num)
46
 
47
+ return True, ""
48
 
49
  def _validate(self, input) -> (bool, str):
50
  mat = [[self.empty_character for i in range(self.size)] for j in range(self.size)]
51
 
52
+ input = input if input else ""
53
  arr = input.split()
54
+ if all(len(l) == 1 for l in arr) and (len(arr) == self.size * self.size):
55
+ arr = ["".join(arr[i:i+self.size]) for i in range(0, len(arr), self.size)]
56
+ if (len(arr) != self.size) or any(len(arr[i]) != self.size for i in range(len(arr))):
57
+ arr = input.split("\n")
58
+ val_msg = f"Your answer is wrong in shape, it should be {self.size}x{self.size} sudoku."
59
+ return False, val_msg
60
+
61
  for i in range(len(arr)):
62
  for j in range(len(arr[i])):
63
  if arr[i][j] not in self.char_to_id:
64
+ val_msg = "There are unrecognized characters, or possibly unfilled cells."
65
  return False, val_msg
66
 
67
  mat[i][j] = self.char_to_id[arr[i][j]]
68
  if arr[i][j] != self.mat[i][j] and self.mat[i][j] != self.empty_character:
69
  val_msg = "One or more characters are replaced"
70
  return False, val_msg
71
+
72
+ return self.is_valid_sudoku(mat)
73
 
74
  def _generate_new_game(self, *args, **kwargs) -> None:
75
  size=kwargs["size"]
 
242
  self.char_to_id = {}
243
  for c_id in range(len(self.characters)):
244
  self.char_to_id[self.characters[c_id]] = c_id
245
+
246
+ @staticmethod
247
+ def example() -> (str, str):
248
+ prompt = ("Please solve the 4x4 sudoku with A,B,C,D as the values and fill _ with the possible value and"
249
+ " only print the answer. Follow the sudoku rule.\nA_CD CD_B _AD_ DCBA")
250
+ answer = ("ABCD\n"
251
+ "CDAB\n"
252
+ "BADC\n"
253
+ "DCBA")
254
+ return prompt, answer
255
+