Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gitignore +2 -3
- agents/Gemma-2-9b-it.ipynb +1 -0
- agents/__init__.py +5 -0
- agents/_reference.py +216 -0
- agents/chatgpt.py +145 -0
- agents/dsr1_distill.py +138 -0
- agents/gemma_2_9b_it.py +104 -0
- agents/llama3.py +102 -0
- agents/qwen2_5_7b_instruct.py +112 -0
- agents/qwen2_5_math.py +137 -0
- agents/runner.py +89 -0
- play_gradio.py +2 -2
- play_helper.py +69 -27
- play_with_auth.py +1 -1
- play_with_hf.py +132 -0
- problemsets/Anagram Scribble_1.json +0 -0
- problemsets/Anagram Scribble_2.json +0 -0
- problemsets/Anagram Scribble_3.json +0 -0
- problemsets/Bracket Game_1.json +0 -0
- problemsets/Bracket Game_2.json +0 -0
- problemsets/Bracket Game_3.json +0 -0
- problemsets/Crossword Arranger_1.json +0 -0
- problemsets/Crossword Arranger_2.json +0 -0
- problemsets/Crossword Arranger_3.json +0 -0
- reval_ana3.py +87 -0
- reval_bracket_all.py +94 -0
- reval_bracket_rerun.py +46 -0
- reval_crosswords_all.py +94 -0
- reval_sudoku_all.py +94 -0
- textgames-scrabble-black2-ss.png +0 -0
- textgames/__init__.py +10 -7
- textgames/anagram_scribble/anagram_scribble.py +40 -8
- textgames/bracket_game/bracket_game.py +97 -38
- textgames/crossword_arranger/crossword_arranger.py +30 -2
- textgames/islands/islands.py +15 -3
- textgames/ordering_text/ordering_text.py +42 -12
- textgames/password_game/password_game.py +10 -0
- textgames/string_search/string_search.py +13 -1
- textgames/sudoku/sudoku.py +32 -8
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
textgames-scrabble-black2-ss.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
*/*.DS_Store
|
2 |
.DS_Store
|
3 |
|
4 |
-
|
5 |
-
problemsets_*
|
6 |
-
|
7 |
user_outputs/
|
|
|
8 |
|
9 |
.idea/
|
10 |
|
|
|
1 |
*/*.DS_Store
|
2 |
.DS_Store
|
3 |
|
4 |
+
agents/*.sh
|
|
|
|
|
5 |
user_outputs/
|
6 |
+
model_outputs/__runs__
|
7 |
|
8 |
.idea/
|
9 |
|
agents/Gemma-2-9b-it.ipynb
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":1,"metadata":{"id":"Rli_enT6lBDT","executionInfo":{"status":"ok","timestamp":1737395007014,"user_tz":-540,"elapsed":5212,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"outputs":[],"source":["##%%\n","import os\n","import pickle\n","import json\n","# import random\n","# import torch\n","# import numpy as np\n","# import argparse\n","# import cohere\n","# from openai import OpenAI\n"]},{"cell_type":"code","source":["##%%\n","# import hashlib\n","from tqdm import tqdm\n","from itertools import product\n","# from collections import Counter\n","\n","# from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM\n","from transformers import AutoTokenizer, AutoModelForCausalLM\n","from textgames import GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name\n"],"metadata":{"id":"dp1F32B8oSfD","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395010583,"user_tz":-540,"elapsed":3547,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"e9adeb5f-70eb-4ca9-dcbb-428e4b28ab41"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stderr","text":["/home/is/frederikus-h/miniconda3/envs/textgame/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n"," from .autonotebook import tqdm as notebook_tqdm\n"]}]},{"cell_type":"code","source":["os.environ.setdefault(\"TEXTGAMES_OUTPUT_DIR\", \"user_outputs\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"2wEu1V1wvxn0","executionInfo":{"status":"ok","timestamp":1737395010664,"user_tz":-540,"elapsed":67,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"cdcad20f-e357-4009-9f4f-0d4495ebd894"},"execution_count":3,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'user_outputs'"]},"metadata":{},"execution_count":3}]},{"cell_type":"code","source":["##%%\n","gen_model_checkpoint = \"google/gemma-2-9b-it\"\n","quantize = True"],"metadata":{"id":"jZF8bkUcojTX","executionInfo":{"status":"ok","timestamp":1737395010678,"user_tz":-540,"elapsed":13,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":4,"outputs":[]},{"cell_type":"code","source":["kwargs = {\n"," \"device_map\": \"auto\",\n","} if quantize else {}"],"metadata":{"id":"VAF5sR9arYzS","executionInfo":{"status":"ok","timestamp":1737395010683,"user_tz":-540,"elapsed":2,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":5,"outputs":[]},{"cell_type":"code","source":["##%%\n","gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **kwargs)\n","tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, **kwargs)"],"metadata":{"id":"tzqldl8ooRVL","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395038547,"user_tz":-540,"elapsed":27859,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"902b638c-e6ce-4f8a-bba2-e9f7241c9a27"},"execution_count":6,"outputs":[{"output_type":"stream","name":"stderr","text":["Loading checkpoint shards: 
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:24<00:00, 6.19s/it]\n"]}]},{"cell_type":"code","source":["gen_model.device"],"metadata":{"id":"FeBUXdkWsWrL","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395038552,"user_tz":-540,"elapsed":3,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"6437d1b7-02f8-47f5-d519-e979cefde795"},"execution_count":7,"outputs":[{"output_type":"execute_result","data":{"text/plain":["device(type='cuda', index=0)"]},"metadata":{},"execution_count":7}]},{"cell_type":"code","source":["def get_gemma_response(text):\n"," # global gen_model, tokenizer\n"," messages = [\n"," {\"role\": \"user\", \"content\": text},\n"," ]\n","\n"," input_ids = tokenizer.apply_chat_template(\n"," messages,\n"," add_generation_prompt=True,\n"," return_tensors=\"pt\"\n"," ).to(gen_model.device)\n","\n"," terminators = [\n"," tokenizer.eos_token_id,\n"," tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n"," ]\n","\n"," outputs = gen_model.generate(\n"," input_ids,\n"," max_new_tokens=100,\n"," eos_token_id=terminators,\n"," do_sample=True,\n"," temperature=.001,\n"," top_p=1,\n"," )\n","\n"," response = outputs[0][input_ids.shape[-1]:]\n"," return tokenizer.decode(response, skip_special_tokens=True)"],"metadata":{"id":"R5D4K-P2sPaj","executionInfo":{"status":"ok","timestamp":1737395038554,"user_tz":-540,"elapsed":1,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":8,"outputs":[]},{"cell_type":"markdown","source":["---\n","Example Call"],"metadata":{"id":"s5FEwOOvxf4h"}},{"cell_type":"code","source":["# @title\n","text = \\\n","\"\"\"\n","Given a set of rules to calculate point, sort the set of words in decreasing order.\n","When there 2 or more words with same point, sort lexicographically.\n","\n","Rules:\n","- every pair of 
consecutive consonant gets 5 points\n","- every pair of consecutive vowel gets 3 points\n","- add 1 point if there exists exactly 1 'g' in the word\n","- word less than 5 characters gets 10 points\n","- word starts with gen gets 100 points\n","- word ends with ta gets -1000 point\n","\n","Words:\n","- genta\n","- winata\n","- hudi\n","- alham\n","- aji\n","- ruochen\n","\n","Print only the answer.\n","\"\"\"\n","\n","print(text)"],"metadata":{"id":"T_tk4hTGsxsR","colab":{"base_uri":"https://localhost:8080/"},"cellView":"form","executionInfo":{"status":"ok","timestamp":1737392776367,"user_tz":-540,"elapsed":27,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"d5ea884f-d0fa-4134-ecd9-690eab51c976"},"execution_count":14,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","Given a set of rules to calculate point, sort the set of words in decreasing order.\n","When there 2 or more words with same point, sort lexicographically.\n","\n","Rules:\n","- every pair of consecutive consonant gets 5 points\n","- every pair of consecutive vowel gets 3 points\n","- add 1 point if there exists exactly 1 'g' in the word\n","- word less than 5 characters gets 10 points\n","- word starts with gen gets 100 points\n","- word ends with ta gets -1000 point\n","\n","Words:\n","- genta\n","- winata\n","- hudi\n","- alham\n","- aji\n","- ruochen\n","\n","Print only the answer.\n","\n"]}]},{"cell_type":"code","source":["# Gold Answer:\n","# - aji 10\n","# - hudi 10\n","# - ruochen 5 3\n","# - alham 5\n","# - genta 5 1 100 -1000\n","# - winata -1000"],"metadata":{"id":"G-5yS4S-rdsN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(get_gemma_response(text))"],"metadata":{"id":"05OI36v6vGoY","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737392724119,"user_tz":-540,"elapsed":6741,"user":{"displayName":"Frederikus 
Hudi","userId":"06160664103998835801"}},"outputId":"fe5d6ed2-d063-4f1c-b2e1-b3af8dbc456e"},"execution_count":9,"outputs":[{"output_type":"stream","name":"stdout","text":["genta\n","winata\n","ruochen\n","hudi\n","alham\n","aji \n","\n"]}]},{"cell_type":"markdown","source":["---\n","Automate run all sessions"],"metadata":{"id":"cxJ4WqHpxi75"}},{"cell_type":"code","source":["for game_name, difficulty_level in product([GAME_NAMES[4], *GAME_NAMES[:4], *GAME_NAMES[5:]], LEVEL_IDS[:3]):\n"," game_cls = _game_class_from_name(game_name)\n"," with open(f\"problemsets/{game_filename(game_name)}_{difficulty_level}.json\", \"r\", encoding=\"utf8\") as f:\n"," sid_prompt_dict = json.load(f)\n","\n"," correct_cnt = 0\n"," for sid, prompt in tqdm(list(sid_prompt_dict.items()), desc=f\"{game_filename(game_name)}_-_{difficulty_level}\"):\n"," cur_game = game_cls()\n"," cur_game.load_game(prompt)\n"," response = get_gemma_response(cur_game.get_prompt()).strip()\n"," solved, val_msg = cur_game.validate(response)\n"," with open(f\"model_outputs/results_gemma_2_9B_it.pkl\", \"ab\") as o:\n"," pickle.dump((f\"{game_filename(game_name)}_{difficulty_level}\", sid, response, (solved, val_msg)), o)\n"," if solved:\n"," correct_cnt += 1\n","\n"," print(f\"{game_name}_-_{difficulty_level}\")\n"," print(f\" Acc.: {correct_cnt / len(sid_prompt_dict):.2%}\")"],"metadata":{"id":"hCTXYpXa1UQ6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"GC-zkVI52IJX"},"execution_count":null,"outputs":[]}]}
|
agents/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Define the __all__ variable
|
2 |
+
__all__ = ["run_with_agent"]
|
3 |
+
|
4 |
+
# Import the submodules
|
5 |
+
from .runner import run_with_agent
|
agents/_reference.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
import cohere
|
8 |
+
from openai import OpenAI
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from collections import Counter
|
13 |
+
|
14 |
+
from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
15 |
+
import hashlib
|
16 |
+
|
17 |
+
|
18 |
+
OPENAI_TOKEN = ""
|
19 |
+
COHERE_TOKEN = ""
|
20 |
+
HF_TOKEN = ""
|
21 |
+
|
22 |
+
|
23 |
+
def argmax(array):
    """Return the index of the largest entry, breaking ties deterministically.

    When several positions hold the maximum, one of them is chosen from a
    SHA-256 digest of the array's raw bytes, so identical input always
    selects the same index.
    """
    peak = np.max(array)
    candidates = np.arange(len(array))[array == peak]
    digest = hashlib.sha256(np.asarray(array).tobytes()).hexdigest()
    pick = int(digest, 16) % len(candidates)
    return candidates[pick]
|
28 |
+
|
29 |
+
|
30 |
+
def logsumexp(x):
    """Compute log(sum(exp(x))), subtracting ``x.max()`` before exponentiating.

    ``x`` must expose a ``.max()`` method, since it is called directly.
    """
    shift = x.max()
    exp_shifted = np.exp(x - shift)
    total = np.sum(exp_shifted)
    return shift + np.log(total)
|
33 |
+
|
34 |
+
|
35 |
+
def normalize(x):
    """Return ``exp(x - logsumexp(x))`` for ``x`` converted to a NumPy array."""
    scores = np.array(x)
    log_total = logsumexp(scores)
    return np.exp(scores - log_total)
|
38 |
+
|
39 |
+
|
40 |
+
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device rather than only the current one, matching the
    # `set_all_seed` helpers used by the other agent scripts.
    torch.cuda.manual_seed_all(seed)
|
45 |
+
|
46 |
+
|
47 |
+
def get_commandr_chat_response(gen_model, gen_model_checkpoint, text, seed):
    """Send ``text`` through ``gen_model.chat`` and return the reply's ``text``.

    The request always names the "command-r" model; ``gen_model_checkpoint``
    is accepted but not used by this function.
    """
    request = {
        "model": "command-r",
        "message": text,
        "temperature": 0,
        "max_tokens": 64,
        "seed": seed,
        "p": 1,
    }
    reply = gen_model.chat(**request)
    return reply.text
|
57 |
+
|
58 |
+
|
59 |
+
def get_mt0_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Encode ``text``, sample up to 10 new tokens and decode the first output.

    The whole first output sequence is decoded (nothing is sliced off).
    ``gen_model_checkpoint`` and ``seed`` are accepted but not used here.
    """
    prompt_ids = tokenizer.encode(text, return_tensors="pt")
    prompt_ids = prompt_ids.to(gen_model.device)

    sampling = dict(max_new_tokens=10, do_sample=True, temperature=0.2, top_p=1)
    generated = gen_model.generate(prompt_ids, **sampling)

    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
72 |
+
|
73 |
+
|
74 |
+
def get_gemma_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Generate a reply to a single user message and return only the new text.

    ``text`` is wrapped as one user turn, rendered with the tokenizer's chat
    template, and up to 10 tokens are sampled. Generation stops on the
    tokenizer's EOS id or the id the tokenizer returns for "<|eot_id|>".
    ``gen_model_checkpoint`` and ``seed`` are accepted but not used here.
    """
    chat = [{"role": "user", "content": text}]
    prompt_ids = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt"
    ).to(gen_model.device)

    stop_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    generation = gen_model.generate(
        prompt_ids,
        max_new_tokens=10,
        eos_token_id=stop_ids,
        do_sample=True,
        temperature=0.2,
        top_p=1,
    )

    # Drop the prompt tokens so only the continuation is decoded.
    prompt_len = prompt_ids.shape[-1]
    return tokenizer.decode(generation[0][prompt_len:], skip_special_tokens=True)
|
101 |
+
|
102 |
+
|
103 |
+
def get_mistral_instruct_chat_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Sample a short reply to one user message via the chat template.

    Returns the decoded tokens generated after the prompt (at most 10).
    ``gen_model_checkpoint`` and ``seed`` are accepted but not used here.
    """
    template_args = {"add_generation_prompt": True, "return_tensors": "pt"}
    encoded = tokenizer.apply_chat_template(
        [{"role": "user", "content": text}], **template_args
    )
    encoded = encoded.to(gen_model.device)

    generate_args = {
        "max_new_tokens": 10,
        # Stop on EOS or on whatever id the tokenizer maps "<|eot_id|>" to.
        "eos_token_id": [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ],
        "do_sample": True,
        "temperature": 0.2,
        "top_p": 1,
    }
    sequences = gen_model.generate(encoded, **generate_args)

    continuation = sequences[0][encoded.shape[-1]:]
    return tokenizer.decode(continuation, skip_special_tokens=True)
|
130 |
+
|
131 |
+
|
132 |
+
def get_llama3_instruct_chat_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    """Answer a single user turn, returning the text generated after the prompt.

    Sampling is limited to 10 new tokens and ends on the tokenizer's EOS id
    or the "<|eot_id|>" token id. ``gen_model_checkpoint`` and ``seed`` are
    accepted but not used here.
    """
    conversation = [dict(role="user", content=text)]
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    input_ids = input_ids.to(gen_model.device)

    eos_id = tokenizer.eos_token_id
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        eos_token_id=[eos_id, eot_id],
        do_sample=True,
        temperature=0.2,
        top_p=1,
    )

    n_prompt_tokens = input_ids.shape[-1]
    new_tokens = outputs[0][n_prompt_tokens:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
|
159 |
+
|
160 |
+
|
161 |
+
def get_openai_chat_response(gen_model, gen_model_checkpoint, text, seed):
    """Request one chat completion for ``text`` and return its message content.

    ``gen_model_checkpoint`` is sent as the model name and ``seed`` as the
    request seed; temperature is 0 and the reply is capped at 64 tokens.
    """
    user_turn = {"role": "user", "content": text}
    completion = gen_model.chat.completions.create(
        model=gen_model_checkpoint,
        messages=[user_turn],
        temperature=0,
        max_tokens=64,
        top_p=1,
        seed=seed,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
|
177 |
+
|
178 |
+
|
179 |
+
def load_model(gen_model_checkpoint, load_in_8bit=False):
    """Instantiate the generation backend that matches ``gen_model_checkpoint``.

    The checkpoint name is matched by substring, in this order: causal-LM
    Hugging Face models, seq2seq Hugging Face models, a second group of
    causal-LM models, OpenAI chat models, and Cohere "command-r".

    Args:
        gen_model_checkpoint: Model identifier to match and load.
        load_in_8bit: For Hugging Face models, load the weights with
            ``device_map="auto"`` and ``load_in_8bit=True``.

    Returns:
        A ``(gen_model, tokenizer)`` tuple. ``tokenizer`` is ``None`` for the
        API-backed clients, and both entries are ``None`` when the checkpoint
        name matches no known family.
    """

    def _names_any(*markers):
        # True when the checkpoint name contains at least one marker.
        return any(marker in gen_model_checkpoint for marker in markers)

    def _load_hf(model_cls):
        # Load model weights and tokenizer from the Hugging Face hub.
        model_kwargs = {"token": HF_TOKEN}
        if load_in_8bit:
            model_kwargs.update(device_map="auto", load_in_8bit=True)
        hf_model = model_cls.from_pretrained(gen_model_checkpoint, **model_kwargs)
        # device_map / load_in_8bit are model-loading options, so the
        # tokenizer is loaded with the auth token only.
        hf_tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
        return hf_model, hf_tokenizer

    if _names_any(
        "mistralai/Mistral-7B-Instruct-v0.3",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "google/gemma-1.1-7b-it",
    ):
        return _load_hf(AutoModelForCausalLM)
    if _names_any("CohereForAI/aya-101", "bigscience/mt0"):
        return _load_hf(AutoModelForSeq2SeqLM)
    # This branch previously read the undefined name `args`, raising
    # NameError for every checkpoint that reached it.
    if _names_any("facebook/xglm", "bigscience/bloomz", "aya-23-8B"):
        return _load_hf(AutoModelForCausalLM)
    if _names_any("gpt-3.5-turbo", "gpt-4"):
        return OpenAI(api_key=OPENAI_TOKEN), None
    if "command-r" in gen_model_checkpoint:
        return cohere.Client(COHERE_TOKEN), None

    return None, None
|
216 |
+
|
agents/chatgpt.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
#%%
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
from pathlib import Path
|
11 |
+
from transformers import set_seed
|
12 |
+
from textgames import GAME_NAMES, LEVEL_IDS, game_filename
|
13 |
+
from agents import run_with_agent
|
14 |
+
|
15 |
+
#%%
|
16 |
+
def set_all_seed(seed=42):
    """Apply ``seed`` through transformers' ``set_seed`` and then directly to
    NumPy, PyTorch and every CUDA device."""
    seeders = (
        set_seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
21 |
+
|
22 |
+
|
23 |
+
#%%
|
24 |
+
def _getenv_as_int(attr, default=None):
|
25 |
+
ret = os.getenv(attr, default)
|
26 |
+
return None if ret is None else int(ret)
|
27 |
+
|
28 |
+
|
29 |
+
# Run configuration read from TG_* environment variables. A value of None
# means the variable was unset and no default applies.
GAME_ST = _getenv_as_int("TG_GAME_ST", None)
GAME_ED = _getenv_as_int("TG_GAME_ED", None)
LVL_ST = _getenv_as_int("TG_LEVEL_ST", None)
LVL_ED = _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST = _getenv_as_int("TG_SID_ST", None)
SID_ED = _getenv_as_int("TG_SID_ED", None)
N_TURNS = _getenv_as_int("TG_N_TURNS", 1)
# TG_ONESHOT is parsed as an integer flag ("0" -> False, non-zero -> True).
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
GPT_MODEL = os.getenv("TG_GPT_MODEL", "")
|
35 |
+
# MAX_NEW_TOKENS = _getenv_as_int("TG_MAX_NEW_TOKENS", 12000)
|
36 |
+
|
37 |
+
|
38 |
+
#%%
|
39 |
+
def preload_responses(n_turns=None):
    """Load stored batch outputs for each turn into a lookup table.

    For every turn ``t`` in ``1..n_turns`` the JSONL path is taken from the
    environment variable ``TG_GPT_OUTPUT_FILE_TURN_<t>``; when that is unset,
    a default path under ``model_outputs/__runs__/chatgpt_4o_mini_results/raw``
    is built from ``GPT_MODEL`` and ``ONE_SHOT``. Loading stops at the first
    turn whose file does not exist; a notice is printed unless that turn is
    the last one.

    Args:
        n_turns: Number of turns to look for. Defaults to the module-level
            ``N_TURNS``.

    Returns:
        dict mapping ``(game_key, turn)`` to ``{sid: content}``, where ``sid``
        and ``game_key`` are the last two '-'-separated fields of each
        record's ``custom_id`` and ``content`` is the first choice's message
        content.
    """
    if n_turns is None:
        n_turns = N_TURNS

    responses_all = dict()
    for _turn in range(1, n_turns + 1):
        fp = os.getenv(f"TG_GPT_OUTPUT_FILE_TURN_{_turn}")
        if fp is None:
            # Build the default path only when no override is configured.
            fp = (f"model_outputs/__runs__/chatgpt_4o_mini_results/raw/batch_output_chatgpt-{GPT_MODEL}_turn{_turn}"
                  f"{'.1s' if ONE_SHOT else '.zs'}.jsonl")
        if not Path(fp).exists():
            if _turn < n_turns:
                print(f" batch_output turn {_turn} is not available. path: \"{fp}\"")
            break
        with open(fp, "r", encoding="utf8") as i:
            # Skip blank lines (e.g. a trailing newline) instead of crashing.
            data = [json.loads(line) for line in i if line.strip()]
        for d in data:
            sid, g = d['custom_id'].rsplit('-', 2)[-2:]
            msg = d['response']['body']['choices'][0]['message']
            responses_all.setdefault((g, _turn), dict())[sid] = msg['content']
    return responses_all
|
62 |
+
RESPONSES_ALL = preload_responses()
|
63 |
+
print(f"len(RESPONSES_ALL) = {len(RESPONSES_ALL)}")
|
64 |
+
|
65 |
+
|
66 |
+
#%%
|
67 |
+
def gpt_postproc(response_txt_batch, *args, **kwargs):
    """Extract the answer portion from a single raw response.

    Two patterns are tried in order: a leading triple-backtick block, then
    everything preceding a "\\n\\nExplanation:\\n" marker. The first pattern
    that matches decides the result; if its stripped capture is empty, or if
    nothing matches, the response is returned unchanged. ``None`` is passed
    through as ``None``.
    """
    text = response_txt_batch
    if text is None:
        return None

    patterns = (
        r'^```\n?([^`]*)\n?```',
        r'((.|\n)*)\n\nExplanation:\n',
    )
    hits = (re.search(pattern, text) for pattern in patterns)
    first_hit = next((hit for hit in hits if hit), None)

    extracted = first_hit.group(1).strip() if first_hit else None
    return extracted or text
|
87 |
+
|
88 |
+
|
89 |
+
#%%
|
90 |
+
def get_gpt_response(texts, game_name, difficulty_level, turn, *args, **kwargs):
    """Return the stored reply for this session, or queue a batch request.

    ``texts`` alternates user / assistant turns, starting with the user. If
    ``RESPONSES_ALL`` holds entries for this game, level and turn, the reply
    for ``kwargs['sid']`` is returned. Otherwise, when the environment
    variable ``TG_GPT_NEXTTURN_OUTPUT_FILE`` is set, one chat-completions
    request line is appended to that file and ``None`` is returned.
    """
    sid = kwargs['sid']  # sid must be fed as params
    roles = ("user", "assistant")
    messages = [
        {"role": roles[i % 2], "content": text}
        for i, text in enumerate(texts)
    ]

    lookup_key = (f"{game_filename(game_name)}_{difficulty_level}", turn)
    stored = RESPONSES_ALL.get(lookup_key, {})
    if stored:
        return stored[sid]

    fp_next = os.getenv("TG_GPT_NEXTTURN_OUTPUT_FILE", None)
    if fp_next:
        request = {
            'custom_id': f"{sid}-{game_filename(game_name)}_{difficulty_level}",
            "method": "POST", "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o-mini-2024-07-18",
                "max_completion_tokens": 200,
                'messages': messages,
                "seed": 42,
                "temperature": 0,
            }
        }
        with open(fp_next, "a", encoding="utf8") as o:
            o.write(json.dumps(request))
            o.write("\n")

    return None
|
121 |
+
|
122 |
+
|
123 |
+
#%%
|
124 |
+
if __name__ == "__main__":
    # The output file name encodes the model, the shot setting and, when
    # set, the starting game / level indices.
    shot_tag = '.1s' if ONE_SHOT else '.zs'
    game_tag = '' if GAME_ST is None else f'.{GAME_ST}'
    level_tag = '' if LVL_ST is None else f'.{LVL_ST}'
    fp_out = (f"model_outputs/__runs__/chatgpt_4o_mini_results/process/results_chatgpt-{GPT_MODEL}"
              f"{shot_tag}{game_tag}{level_tag}.jsonl")

    set_all_seed()

    # Pass an explicit session-id list only when a bound was configured.
    if SID_ST or SID_ED:
        sid_indices = [f"session_{r:04}" for r in range(SID_ST or 0, SID_ED or 1000)]
    else:
        sid_indices = None

    run_with_agent(
        fp_out,
        get_gpt_response,
        gpt_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=sid_indices,
        prepend_example=ONE_SHOT,
        assistant_uses_raw_response=False,
    )
|
agents/dsr1_distill.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
#%%
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
|
9 |
+
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
|
10 |
+
from agents import run_with_agent
|
11 |
+
|
12 |
+
#%%
|
13 |
+
def set_all_seed(seed=42):
    """Seed every random source used by the run with the same ``seed``.

    Calls transformers' ``set_seed`` first, then seeds NumPy, PyTorch and all
    CUDA devices explicitly.
    """
    for seed_fn in (set_seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
|
18 |
+
|
19 |
+
|
20 |
+
#%%
|
21 |
+
def _getenv_as_int(attr, default=None):
|
22 |
+
ret = os.getenv(attr, default)
|
23 |
+
return None if ret is None else int(ret)
|
24 |
+
|
25 |
+
|
26 |
+
# Run configuration read from TG_* environment variables. A value of None
# means the variable was unset and no default applies.
GAME_ST = _getenv_as_int("TG_GAME_ST", None)
GAME_ED = _getenv_as_int("TG_GAME_ED", None)
LVL_ST = _getenv_as_int("TG_LEVEL_ST", None)
LVL_ED = _getenv_as_int("TG_LEVEL_ED", '3')
SID_ST = _getenv_as_int("TG_SID_ST", None)
SID_ED = _getenv_as_int("TG_SID_ED", None)
N_TURNS = _getenv_as_int("TG_N_TURNS", 1)
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
MAX_NEW_TOKENS = _getenv_as_int("TG_MAX_NEW_TOKENS", 12000)
DSR1_SIZE = os.getenv("TG_DSR1_SIZE", "14")  # {1.5, 7, 8, 14, 32, 70}
# Maps each size key to the "<family>-<size>" fragment of the model name.
DSR1_NAME = {
    size: f"{family}-{size}"
    for size, family in (
        ("1.5", "Qwen"),
        ("7", "Qwen"),
        ("8", "Llama"),
        ("14", "Qwen"),
        ("32", "Qwen"),
        ("70", "Llama"),
    )
}
|
41 |
+
|
42 |
+
|
43 |
+
#%%
|
44 |
+
def dsr1_postproc(response_txt_batch, *args, **kwargs):
    """Extract the final answer from a single raw model response.

    Patterns are tried in order and the first match wins:
      1. the contents of ``\\boxed{...}`` (greedy, up to the last ``}``),
      2. everything after ``</think>`` followed by a newline,
      3. a leading triple-backtick block.
    When nothing matches, the first 256 characters of the response are
    returned. All results are stripped of surrounding whitespace.

    Args:
        response_txt_batch: The raw response text, or ``None``.

    Returns:
        The extracted answer; ``""`` for an empty or ``None`` response.
    """
    response_txt = response_txt_batch
    # Guard first: the regex search would raise TypeError on None, which
    # made the empty-response fallback unreachable for that case.
    if not response_txt:
        return ""

    patterns = (
        r'\\boxed\{([\s\S]*)}',
        r'</think>\n([\s\S]*)$',
        r'^```\n?([^`]*)\n?```',
    )
    for pattern in patterns:
        found = re.search(pattern, response_txt)
        if found:
            return found.group(1).strip()

    return response_txt[:256].strip()
|
63 |
+
|
64 |
+
|
65 |
+
#%%
|
66 |
+
def get_dsr1_response(texts_batch, *args, **kwargs):
    """Generate the model's reply for one conversation.

    Relies on the module-level ``tokenizer``, ``model`` and ``MAX_NEW_TOKENS``.

    Args:
        texts_batch: List of turn texts for a single conversation,
            alternating user / assistant and starting with the user.

    Returns:
        The decoded, stripped text generated after the prompt.
    """
    # global model, tokenizer
    texts_batch = [texts_batch]
    for texts in texts_batch:
        # NOTE: this rewrites the caller's list in place, as before.
        if len(texts) > 1 and texts[1].startswith('Correct guess.'):
            texts[1] = f"\\boxed{{{texts[1]}}}"
    messages = [
        [
            {"role": "user",
             "content": f"{text}\nPlease reason step by step, and put your final answer within \\boxed{{}} as plain text."}
            if i % 2 == 0 else
            # Content must be the string itself; `{text}` built a one-element
            # set, which the chat template cannot render as the message text.
            {"role": "assistant", "content": text}
            for i, text in enumerate(texts)
        ]
        for texts in texts_batch
    ]
    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer(text_inputs, return_tensors="pt", add_special_tokens=False).to(model.device)
    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Strip the prompt tokens from each output row.
    generated_ids = [
        _output_ids[len(input_ids):] for input_ids, _output_ids in zip(model_inputs.input_ids, output_ids)
    ]
    response = [r.strip() for r in tokenizer.batch_decode(generated_ids, skip_special_tokens=True)]
    return response[0]
|
99 |
+
|
100 |
+
|
101 |
+
#%%
|
102 |
+
# response = get_dsr1_response(texts)
|
103 |
+
# print(dsr1_postproc(response))
|
104 |
+
|
105 |
+
|
106 |
+
#%%
|
107 |
+
if __name__ == "__main__":
    # The output file name encodes the model size, the shot setting and,
    # when set, the starting game / level indices.
    shot_tag = '.1s' if ONE_SHOT else '.zs'
    game_tag = '' if GAME_ST is None else f'.{GAME_ST}'
    level_tag = '' if LVL_ST is None else f'.{LVL_ST}'
    fp_out = (f"model_outputs/__runs__/results_deepseek-r1-distill-{DSR1_SIZE}b"
              f"{shot_tag}{game_tag}{level_tag}.jsonl")

    set_all_seed()

    # `tokenizer` and `model` are module-level names read by get_dsr1_response.
    model_name = f"deepseek-ai/DeepSeek-R1-Distill-{DSR1_NAME[DSR1_SIZE]}B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype="auto",
    )
    # Clear the sampling settings on the generation config; generation is
    # requested with do_sample=False.
    for knob in ("temperature", "top_k", "top_p"):
        setattr(model.generation_config, knob, None)

    # Pass an explicit session-id list only when a bound was configured.
    if SID_ST or SID_ED:
        sid_indices = [f"session_{r:04}" for r in range(SID_ST or 0, SID_ED or 1000)]
    else:
        sid_indices = None

    run_with_agent(
        fp_out,
        get_dsr1_response,
        dsr1_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=sid_indices,
        prepend_example=ONE_SHOT,
        assistant_uses_raw_response=False,
    )
|
agents/gemma_2_9b_it.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
#%%
|
6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
7 |
+
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
|
8 |
+
from agents import run_with_agent
|
9 |
+
|
10 |
+
|
11 |
+
#%%
|
12 |
+
def _getenv_as_int(attr, default=None):
    """Read environment variable ``attr`` as an int.

    An unset variable, or one set to an empty/blank string (e.g. ``TG_X=``
    in a shell script), falls back to ``default``. ``default`` may be an
    int, a numeric string, or None.

    Returns the integer value, or None when neither the variable nor the
    default provides one. Raises ValueError if the value is not an integer.
    """
    ret = os.getenv(attr)
    if ret is None or not ret.strip():
        # Blank counts as "not configured" instead of crashing in int("").
        ret = default
    return None if ret is None else int(ret)
|
15 |
+
|
16 |
+
|
17 |
+
GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
|
18 |
+
LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
|
19 |
+
SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
|
20 |
+
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)
|
21 |
+
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
|
22 |
+
GEMMA_SIZE = int(os.getenv("TG_GEMMA_SIZE", "9")) # {3, 9, 27}
|
23 |
+
|
24 |
+
|
25 |
+
#%%
|
26 |
+
def gemma_postproc(response_txt, game_name, difficulty_level, *args, **kwargs):
    """Extract the answer from a raw Gemma response.

    The following are tried in order:

    1. a fenced block (three backticks) at the very start of the response:
       its content, stripped, with every space removed;
    2. the first bold span (``**...**``, optionally double-quoted):
       its content, stripped;
    3. otherwise the response itself.

    ``game_name`` and ``difficulty_level`` are part of the post-processor
    call signature but are not used. Returns "" for a None/empty response.
    """
    if not response_txt:
        # Guard up front: re.search() raises TypeError on None, which made
        # the final `or ""` fallback unreachable for that case.
        return ""

    # if game_name in [THE_GAMES[i] for i in ["1", "7"]]:  # crossword
    fenced = re.search(r'^```\n?([^`]*)\n?```', response_txt)
    if fenced:
        return fenced.group(1).strip().replace(" ", "")

    # elif game_name == THE_GAMES["6"]:  # anagram
    bold = re.search(r'\*\*\"?([^\"*]*)\"?\*\*', response_txt)
    if bold:
        return bold.group(1).strip()

    return response_txt
|
40 |
+
|
41 |
+
|
42 |
+
#%%
|
43 |
+
def get_gemma_response(texts, game_name, difficulty_level, turn, *args, **kwargs):
    """Generate Gemma's next reply for an alternating user/model conversation.

    ``texts`` holds the turns in order, starting with the user: even indices
    become "user" messages, odd indices "model" messages. The function reads
    the module-level ``gen_model`` and ``tokenizer`` created in the
    ``__main__`` section. ``game_name``, ``difficulty_level`` and ``turn``
    are accepted for the runner's call signature but unused.

    Returns the newly generated text only (prompt tokens dropped, special
    tokens skipped, surrounding whitespace stripped). Decoding is greedy and
    capped at 100 new tokens.
    """
    # global gen_model, tokenizer
    messages = [
        {"role": ("model" if i % 2 else "user"), "content": text}
        for i, text in enumerate(texts)
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(gen_model.device)

    # Gemma closes a turn with "<end_of_turn>". The previous "<|eot_id|>" is
    # a Llama-3 token that Gemma's vocabulary does not contain, so generation
    # was not stopped at the end of the model's turn.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<end_of_turn>")
    ]

    gen_model.generation_config.temperature = None
    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=100,
        eos_token_id=terminators,
        do_sample=False,
    )

    # Keep only the tokens generated after the prompt.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True).strip()
|
73 |
+
|
74 |
+
|
75 |
+
#%%
|
76 |
+
if __name__ == "__main__":
    # The output file name encodes the model size, the prompting mode
    # (".1s" one-shot / ".zs" zero-shot) and, when configured, the start
    # indices of the game and level slices.
    shot_tag = ".1s" if ONE_SHOT else ".zs"
    game_tag = "" if GAME_ST is None else f".{GAME_ST}"
    level_tag = "" if LVL_ST is None else f".{LVL_ST}"
    fp_out = f"model_outputs/results_gemma-2-{GEMMA_SIZE}b-it{shot_tag}{game_tag}{level_tag}.jsonl"
    gen_model_checkpoint = f"google/gemma-2-{GEMMA_SIZE}b-it"

    # NOTE(review): despite its name, this flag only toggles device_map="auto".
    quantize = True
    _kwargs = {}
    if quantize:
        _kwargs["device_map"] = "auto"

    gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, **_kwargs)
    print(f" > model.dtype: {gen_model.dtype}")

    # Sessions are restricted only when at least one bound is configured.
    if SID_ST or SID_ED:
        sid_indices = [f"session_{i:04}" for i in range(SID_ST or 0, SID_ED or 1000)]
    else:
        sid_indices = None

    run_with_agent(
        fp_out,
        get_gemma_response,
        gemma_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=sid_indices,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )
|
agents/llama3.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
#%%
|
6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
7 |
+
from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS
|
8 |
+
from agents import run_with_agent
|
9 |
+
|
10 |
+
|
11 |
+
#%%
|
12 |
+
def _getenv_as_int(attr, default=None):
    """Read environment variable ``attr`` as an int.

    An unset variable, or one set to an empty/blank string, falls back to
    ``default`` (an int, a numeric string, or None).

    Returns the integer value, or None when neither the variable nor the
    default provides one. Raises ValueError if the value is not an integer.
    """
    ret = os.getenv(attr)
    if ret is None or not ret.strip():
        # Blank counts as "not configured" instead of crashing in int("").
        ret = default
    return None if ret is None else int(ret)
|
15 |
+
|
16 |
+
|
17 |
+
GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
|
18 |
+
LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
|
19 |
+
SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
|
20 |
+
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)
|
21 |
+
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
|
22 |
+
LLAMA_SIZE = os.getenv("TG_LLAMA_SIZE", "1-8")
|
23 |
+
|
24 |
+
|
25 |
+
#%%
|
26 |
+
def llama_postproc(response_txt, *args, **kwargs):
    """Return the Llama response unchanged, mapping a None/empty one to "".

    Extra positional and keyword arguments are accepted for the runner's
    post-processor call signature and ignored.
    """
    if response_txt:
        return response_txt
    return ""
|
39 |
+
|
40 |
+
|
41 |
+
#%%
|
42 |
+
def get_llama_response(texts, *args, **kwargs):
    """Generate Llama's next reply for an alternating user/assistant dialogue.

    ``texts`` holds the turns in order, starting with the user: even indices
    become "user" messages, odd indices "assistant" messages. The function
    reads the module-level ``model`` and ``tokenizer`` created in the
    ``__main__`` section; other arguments are ignored.

    Returns the newly generated text only (prompt tokens dropped, special
    tokens skipped, surrounding whitespace stripped). Decoding is greedy and
    capped at 128 new tokens.
    """
    # global model, tokenizer

    messages = [
        {"role": ("assistant" if i % 2 else "user"), "content": text}
        for i, text in enumerate(texts)
    ]

    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered chat template already contains the special tokens (BOS
    # included), so they must not be added again when tokenizing the string;
    # otherwise the prompt starts with a duplicated BOS token.
    model_inputs = tokenizer(
        [text_inputs], return_tensors="pt", add_special_tokens=False,
    ).to(model.device)

    # Greedy decoding: unset the sampling parameters on the generation config.
    model.generation_config.do_sample = False
    model.generation_config.temperature = None
    model.generation_config.top_k = None
    model.generation_config.top_p = None
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=128,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt tokens from each returned sequence.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response.strip()
|
73 |
+
|
74 |
+
|
75 |
+
#%%
|
76 |
+
if __name__ == "__main__":
    # The output file name encodes the model size, the prompting mode
    # (".1s" one-shot / ".zs" zero-shot) and, when configured, the start
    # indices of the game and level slices.
    shot_tag = ".1s" if ONE_SHOT else ".zs"
    game_tag = "" if GAME_ST is None else f".{GAME_ST}"
    level_tag = "" if LVL_ST is None else f".{LVL_ST}"
    fp_out = (f"model_outputs/__runs__/results_llama-3.{LLAMA_SIZE}b-instruct"
              f"{shot_tag}{game_tag}{level_tag}.jsonl")

    model_name = f"meta-llama/Llama-3.{LLAMA_SIZE}B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="bfloat16")

    # Sessions are restricted only when at least one bound is configured.
    if SID_ST or SID_ED:
        sid_indices = [f"session_{i:04}" for i in range(SID_ST or 0, SID_ED or 1000)]
    else:
        sid_indices = None

    run_with_agent(
        fp_out,
        get_llama_response,
        llama_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=sid_indices,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )
|
agents/qwen2_5_7b_instruct.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
#%%
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
|
9 |
+
from textgames import GAME_NAMES, LEVEL_IDS
|
10 |
+
from agents import run_with_agent
|
11 |
+
|
12 |
+
|
13 |
+
#%%
|
14 |
+
def set_all_seed(seed=42):
    """Seed transformers, NumPy and PyTorch (CPU and every CUDA device) with ``seed``."""
    seeders = (set_seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
|
19 |
+
|
20 |
+
|
21 |
+
#%%
|
22 |
+
def _getenv_as_int(attr, default=None):
    """Read environment variable ``attr`` as an int.

    An unset variable, or one set to an empty/blank string, falls back to
    ``default`` (an int, a numeric string, or None).

    Returns the integer value, or None when neither the variable nor the
    default provides one. Raises ValueError if the value is not an integer.
    """
    ret = os.getenv(attr)
    if ret is None or not ret.strip():
        # Blank counts as "not configured" instead of crashing in int("").
        ret = default
    return None if ret is None else int(ret)
|
25 |
+
|
26 |
+
|
27 |
+
GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
|
28 |
+
LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
|
29 |
+
SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
|
30 |
+
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)
|
31 |
+
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
|
32 |
+
QWEN_SIZE = int(os.getenv("TG_QWEN_SIZE", "32")) # {3, 7, 14, 32, 72} unsupported: {0.5, 1.5}
|
33 |
+
|
34 |
+
|
35 |
+
#%%
|
36 |
+
def qwen_postproc(response_txt, game_name, difficulty_level, *args, **kwargs):
    """Return the Qwen response unchanged, mapping a None/empty one to "".

    ``game_name`` and ``difficulty_level`` belong to the runner's
    post-processor call signature and are not used.
    """
    if response_txt:
        return response_txt
    return ""
|
49 |
+
|
50 |
+
|
51 |
+
#%%
|
52 |
+
def get_qwen_response(texts_batch, game_name, difficulty_level, turn, *args, **kwargs):
    """Generate Qwen's next reply for an alternating user/assistant dialogue.

    ``texts_batch`` is a single conversation (batching is not supported yet):
    a list of turns starting with the user, so even indices become "user"
    messages and odd indices "assistant" messages; a fixed system message is
    prepended. The function reads the module-level ``model`` and
    ``tokenizer`` created in the ``__main__`` section. ``game_name``,
    ``difficulty_level`` and ``turn`` are unused.

    Returns the newly generated text only (prompt tokens dropped, special
    tokens skipped, surrounding whitespace stripped). Decoding is greedy and
    capped at 128 new tokens.
    """
    # global model, tokenizer
    texts_batch = [texts_batch]  # currently we do not support batch
    messages = [[
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        *[{"role": ("assistant" if i % 2 else "user"), "content": text} for i, text in enumerate(texts)]
    ] for texts in texts_batch]

    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # For a batch of conversations the template call already yields a list of
    # prompt strings. Wrapping it in another list would hand the tokenizer a
    # nested list, so it is passed through as-is (as in qwen2_5_math.py).
    model_inputs = tokenizer(text_inputs, return_tensors="pt").to(model.device)

    # Greedy decoding: unset the sampling parameters on the generation config.
    model.generation_config.temperature = None
    model.generation_config.top_k = None
    model.generation_config.top_p = None
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=128,
        do_sample=False,
    )
    # Drop the prompt tokens from each returned sequence.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response.strip()
|
81 |
+
|
82 |
+
|
83 |
+
#%%
|
84 |
+
if __name__ == "__main__":
    # The output file name encodes the model size, the prompting mode
    # (".1s" one-shot / ".zs" zero-shot) and, when configured, the start
    # indices of the game and level slices.
    shot_tag = ".1s" if ONE_SHOT else ".zs"
    game_tag = "" if GAME_ST is None else f".{GAME_ST}"
    level_tag = "" if LVL_ST is None else f".{LVL_ST}"
    fp_out = (f"model_outputs/__runs__/results_qwen2-5-{QWEN_SIZE}b-instruct"
              f"{shot_tag}{game_tag}{level_tag}.jsonl")

    set_all_seed()
    model_name = f"Qwen/Qwen2.5-{QWEN_SIZE}B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
    print(f" > model.dtype: {model.dtype}")

    # Sessions are restricted only when at least one bound is configured.
    if SID_ST or SID_ED:
        sid_indices = [f"session_{i:04}" for i in range(SID_ST or 0, SID_ED or 1000)]
    else:
        sid_indices = None

    run_with_agent(
        fp_out,
        get_qwen_response,
        qwen_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=sid_indices,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
    )
|
agents/qwen2_5_math.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
#%%
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed, BitsAndBytesConfig
|
9 |
+
from textgames import GAME_NAMES, LEVEL_IDS
|
10 |
+
from agents import run_with_agent
|
11 |
+
|
12 |
+
|
13 |
+
#%%
|
14 |
+
def set_all_seed(seed=42):
    """Seed transformers, NumPy and PyTorch (CPU and every CUDA device) with ``seed``."""
    seeders = (set_seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
|
19 |
+
|
20 |
+
|
21 |
+
#%%
|
22 |
+
def _getenv_as_int(attr, default=None):
    """Read environment variable ``attr`` as an int.

    An unset variable, or one set to an empty/blank string, falls back to
    ``default`` (an int, a numeric string, or None).

    Returns the integer value, or None when neither the variable nor the
    default provides one. Raises ValueError if the value is not an integer.
    """
    ret = os.getenv(attr)
    if ret is None or not ret.strip():
        # Blank counts as "not configured" instead of crashing in int("").
        ret = default
    return None if ret is None else int(ret)
|
25 |
+
|
26 |
+
|
27 |
+
GAME_ST, GAME_ED = _getenv_as_int("TG_GAME_ST", None), _getenv_as_int("TG_GAME_ED", None)
|
28 |
+
LVL_ST, LVL_ED = _getenv_as_int("TG_LEVEL_ST", None), _getenv_as_int("TG_LEVEL_ED", '3')
|
29 |
+
SID_ST, SID_ED = _getenv_as_int("TG_SID_ST", None), _getenv_as_int("TG_SID_ED", None)
|
30 |
+
N_TURNS = _getenv_as_int("TG_N_TURNS", 3)
|
31 |
+
ONE_SHOT = bool(int(os.getenv("TG_ONESHOT", "0")))
|
32 |
+
# MAX_NEW_TOKENS = _getenv_as_int("TG_MAX_NEW_TOKENS", 4096)
|
33 |
+
QWEN_MATH_SIZE = os.getenv("TG_QWEN_MATH_SIZE", "7") # {1.5, 7, 72}
|
34 |
+
QUANTIZE = _getenv_as_int("TG_QUANTIZE", 4)
|
35 |
+
|
36 |
+
|
37 |
+
#%%
|
38 |
+
def _last_boxed_content(text):
    r"""Return the content of the last brace-balanced ``\boxed{...}`` in ``text``, or None."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    while start != -1:
        body_start = start + len(marker)
        depth = 0
        # Scan from the marker's opening brace and stop at its matching "}".
        for pos in range(body_start - 1, len(text)):
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[body_start:pos]
        # This occurrence is never closed (e.g. truncated output): try an earlier one.
        start = text.rfind(marker, 0, start)
    return None


def qwenmath_postproc(response_txt_batch, *args, **kwargs):
    r"""Extract the final answer from a raw Qwen2.5-Math response.

    ``response_txt_batch`` is a single response string (batching is not
    supported). The following are tried in order:

    1. the last brace-balanced ``\boxed{...}``: its content, stripped.
       Braces are matched explicitly, so text following the answer that
       happens to contain ``}`` is no longer swallowed into it;
    2. a fenced block (three backticks) at the very start of the response:
       its content, stripped;
    3. otherwise the response itself.

    Returns "" for a None/empty response. Extra arguments are ignored.
    """
    response_txt = response_txt_batch
    if not response_txt:
        # Guard up front: the pattern search below raises TypeError on None.
        return ""

    boxed = _last_boxed_content(response_txt)
    if boxed is not None:
        return boxed.strip()

    fenced = re.search(r'^```\n?([^`]*)\n?```', response_txt)
    if fenced:
        return fenced.group(1).strip()

    return response_txt
|
57 |
+
|
58 |
+
|
59 |
+
#%%
|
60 |
+
def get_qwenmath_response(texts_batch, *args, **kwargs):
    """Generate Qwen2.5-Math's next reply for an alternating dialogue.

    ``texts_batch`` is a single conversation (batching is not supported): a
    list of turns starting with the user, so even indices become "user"
    messages and odd indices "assistant" messages; a fixed system message
    asking for a boxed final answer is prepended. When the conversation
    opens with a solved example (third turn starts with 'Correct guess.'),
    the example answer is shown to the model wrapped in ``\\boxed{}`` to
    match that instruction.

    The caller's list is never modified. The function reads the
    module-level ``model`` and ``tokenizer`` created in the ``__main__``
    section; other arguments are ignored.

    Returns the newly generated text only (prompt tokens dropped, special
    tokens skipped, surrounding whitespace stripped). Decoding is greedy and
    capped at 512 new tokens.
    """
    # global model, tokenizer
    # Work on a copy: the runner passes the same history list on every turn,
    # so boxing the example answer in place wrapped it again on each call
    # (\boxed{\boxed{...}}).
    texts = list(texts_batch)
    # `> 2` (not `> 1`) so that texts[2] is guaranteed to exist.
    if len(texts) > 2 and texts[2].startswith('Correct guess.'):
        texts[1] = f"\\boxed{{{texts[1]}}}"

    messages = [
        [
            {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{} as plain text."},
            *[{"role": ("user" if i % 2 == 0 else "assistant"), "content": text} for i, text in enumerate(texts)],
        ]
    ]
    # print(f"\n{messages[0]}", end="\n=====\n\n")

    text_inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # The rendered template already carries its special tokens.
    model_inputs = tokenizer(text_inputs, return_tensors="pt", add_special_tokens=False).to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=False,
    )
    # Drop the prompt tokens from each returned sequence.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response.strip()
|
93 |
+
|
94 |
+
|
95 |
+
#%%
|
96 |
+
if __name__ == "__main__":
    # The output file name encodes the model size, the QUANTIZE setting, the
    # prompting mode (".1s" one-shot / ".zs" zero-shot) and, when configured,
    # the start indices of the game and level slices.
    shot_tag = ".1s" if ONE_SHOT else ".zs"
    game_tag = "" if GAME_ST is None else f".{GAME_ST}"
    level_tag = "" if LVL_ST is None else f".{LVL_ST}"
    fp_out = (f"model_outputs/__runs__/results_qwen2-5-math-{QWEN_MATH_SIZE}b-instruct_{QUANTIZE}bit"
              f"{shot_tag}{game_tag}{level_tag}.jsonl")

    set_all_seed()
    # Only the 72B checkpoint is loaded through bitsandbytes (8-bit when
    # QUANTIZE == 8, 4-bit otherwise); every other size uses device_map="auto".
    _additional_kwargs = {"device_map": "auto"}
    if QWEN_MATH_SIZE in ['72'] and QUANTIZE < 16:
        if QUANTIZE == 8:
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            bnb_config = BitsAndBytesConfig(load_in_4bit=True)
        _additional_kwargs = {
            "quantization_config": bnb_config,
            "low_cpu_mem_usage": True,
        }

    model_name = f"Qwen/Qwen2.5-Math-{QWEN_MATH_SIZE}B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", **_additional_kwargs)
    print(f" > model.dtype: {model.dtype}")

    # Sessions are restricted only when at least one bound is configured.
    if SID_ST or SID_ED:
        sid_indices = [f"session_{i:04}" for i in range(SID_ST or 0, SID_ED or 1000)]
    else:
        sid_indices = None

    run_with_agent(
        fp_out,
        get_qwenmath_response,
        qwenmath_postproc,
        n_turns=N_TURNS,
        game_names_list=GAME_NAMES[GAME_ST:GAME_ED],
        level_ids_list=LEVEL_IDS[LVL_ST:LVL_ED],
        sid_indices=sid_indices,
        prepend_example=ONE_SHOT,
        # remove_if_output_file_exist=False,
        assistant_uses_raw_response=True,
    )
|
agents/runner.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
|
5 |
+
from textgames import GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
from itertools import product
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Union, Callable
|
11 |
+
|
12 |
+
|
13 |
+
def response_postprocess(response_txt, game_name, difficulty_level):
    """Default post-processor: pass the response through, mapping None/empty to "".

    ``game_name`` and ``difficulty_level`` are part of the post-processor
    call signature and are not used here.
    """
    if response_txt:
        return response_txt
    return ""
|
15 |
+
|
16 |
+
|
17 |
+
def run_with_agent(fp_out: Union[str, Path],
                   get_response: Callable,
                   get_postprocess: Callable = response_postprocess,
                   n_turns=3,
                   game_names_list=GAME_NAMES,
                   level_ids_list=LEVEL_IDS[:3],
                   sid_indices=None,  # sid_index_range=range(0, 1000),
                   remove_if_output_file_exist=True,
                   prepend_example=False,
                   assistant_uses_raw_response=True,
                   ) -> None:
    """Play every selected game session with an agent and log each turn as JSONL.

    For each (game, level) pair the problem set
    ``problemsets/<game file name>_<level>.json`` is loaded, and each session
    is played for at most ``n_turns`` turns. A session ends early once it is
    solved or an exception is raised. One JSON line per turn is appended to
    ``fp_out``, and a per-(game, level) summary is printed.

    Args:
        fp_out: Output JSONL path; its parent directory is created.
        get_response: Called as ``get_response(texts, game_name,
            difficulty_level, turn, sid=sid)``; returns the raw model reply.
        get_postprocess: Called as ``get_postprocess(response_raw, game_name,
            difficulty_level)``; returns the answer handed to validation.
        n_turns: Maximum number of attempts per session.
        game_names_list: Games to play.
        level_ids_list: Difficulty level ids to play.
        sid_indices: Session keys to play, or None for all sessions in the
            problem set. A key missing from the problem set raises KeyError.
        remove_if_output_file_exist: Truncate ``fp_out`` before running.
        prepend_example: Start each dialogue with the game's solved example.
        assistant_uses_raw_response: Put the raw reply (True) or the
            post-processed answer (False) into the dialogue history.
    """
    os.makedirs(os.path.dirname(os.path.abspath(fp_out)), exist_ok=True)
    print(fp_out)
    if remove_if_output_file_exist:
        # Opening for writing truncates any previous results.
        with open(fp_out, "wb"):
            pass

    for game_name, difficulty_level in product(game_names_list, level_ids_list):
        game_str = f"{game_filename(game_name)}_{difficulty_level}"
        game_cls = _game_class_from_name(game_name)
        with open(f"problemsets/{game_str}.json", "r", encoding="utf8") as f:
            sid_prompt_dict = json.load(f)
        if sid_indices is not None:
            sid_prompt_dict = {k: sid_prompt_dict[k] for k in sid_indices}

        correct_cnt, exception_cnt = 0, 0
        for sid, prompt in tqdm(sid_prompt_dict.items(), desc=game_str, total=len(sid_prompt_dict)):
            cur_game = game_cls()
            cur_game.load_game(prompt)
            if prepend_example:
                texts = [*cur_game.example(), f"Correct guess. Now let's try another example.\n{cur_game.get_prompt()}"]
            else:
                texts = [cur_game.get_prompt()]
            for turn in range(1, n_turns + 1):
                response_raw, response, e = None, None, None
                solved, val_msg = False, None
                try:
                    response_raw = get_response(texts, game_name, difficulty_level, turn, sid=sid)
                    response = get_postprocess(response_raw, game_name, difficulty_level)
                    texts.append(response_raw if assistant_uses_raw_response else response)
                    solved, val_msg = (False, None) if response is None else cur_game.validate(response)
                    texts.append(
                        f"Bad guess (Wrong Answer).\n{val_msg}\nPlease try again and print the answer only."
                        if not solved else "Correct guess."
                    )
                except Exception as _e:
                    # Deliberate best-effort: the failure is recorded in the
                    # log line below and ends this session.
                    e = _e
                # The file is reopened per turn, so each record is written
                # out as soon as the turn finishes.
                with open(fp_out, "a", encoding="utf8") as o:
                    json.dump({
                        "game": game_str,
                        "session": sid,
                        "turn": turn,
                        "response": response,
                        "solved": solved,
                        "val_msg": val_msg,
                        "response_raw": response_raw,
                        "error": repr(e) if e else e,
                    }, o, ensure_ascii=False)
                    o.write("\n")
                if solved:
                    correct_cnt += 1
                if e:
                    exception_cnt += 1
                if solved or e:
                    break

        # An empty session selection must not crash the summary with a
        # ZeroDivisionError; report 0.00% instead.
        n_sessions = max(len(sid_prompt_dict), 1)
        print(f"{game_filename(game_name)}_-_{difficulty_level}")
        print(f" > Correct: {correct_cnt:>6,} ({correct_cnt / n_sessions:.2%})")
        print(f" > Error : {exception_cnt:>6,} ({exception_cnt / n_sessions:.2%})")
|
89 |
+
|
play_gradio.py
CHANGED
@@ -51,7 +51,7 @@ def greet(request: gr.Request):
|
|
51 |
|
52 |
#%%
|
53 |
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
|
54 |
-
((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle),
|
55 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
56 |
) = declare_components(demo, greet)
|
57 |
|
@@ -64,7 +64,7 @@ demo.launch(
|
|
64 |
auth=file_based_auth,
|
65 |
favicon_path=favicon_path if os.path.exists(favicon_path) else None,
|
66 |
share=True,
|
67 |
-
|
68 |
)
|
69 |
|
70 |
|
|
|
51 |
|
52 |
#%%
|
53 |
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
|
54 |
+
((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
|
55 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
56 |
) = declare_components(demo, greet)
|
57 |
|
|
|
64 |
auth=file_based_auth,
|
65 |
favicon_path=favicon_path if os.path.exists(favicon_path) else None,
|
66 |
share=True,
|
67 |
+
show_api=False,
|
68 |
)
|
69 |
|
70 |
|
play_helper.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
# %%
|
2 |
import os
|
3 |
import time
|
|
|
4 |
import pandas as pd
|
5 |
import gradio as gr
|
6 |
import hashlib
|
@@ -19,19 +20,27 @@ from googleapiclient.discovery import build
|
|
19 |
from googleapiclient.errors import HttpError
|
20 |
from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload
|
21 |
|
|
|
|
|
|
|
22 |
|
23 |
# %%
|
24 |
-
def declare_components(demo, greet):
|
25 |
with gr.Row():
|
26 |
with gr.Column(scale=1):
|
27 |
m = gr.Markdown("Welcome to TextGames!", elem_id="md-greeting")
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
29 |
with gr.Column(scale=2):
|
30 |
-
solved_games_df = gr.DataFrame(headers=[g.split('\t', 1)[0] for g in GAME_NAMES], label="
|
31 |
-
interactive=False, elem_id="df-solved-games")
|
32 |
-
|
33 |
-
|
34 |
-
new_game_btn = gr.Button("Start Game", elem_id="btn-start-game")
|
35 |
render_toggle = gr.Checkbox(False, visible=False, interactive=False)
|
36 |
|
37 |
# cur_game_start = gr.BrowserState()
|
@@ -41,9 +50,12 @@ def declare_components(demo, greet):
|
|
41 |
user_state = gr.State()
|
42 |
uid_state = gr.State()
|
43 |
|
|
|
|
|
|
|
44 |
session_state.change(
|
45 |
-
lambda s: session_state_change_fn(s, 2, 0,
|
46 |
-
[session_state], [game_radio, level_radio, new_game_btn, logout_btn], js=js_remove_input_helper,
|
47 |
)
|
48 |
new_game_btn.click(check_to_start_new_game, [game_radio, level_radio, user_state, uid_state], [session_state])
|
49 |
solved_games.change(solved_games_change_fn, solved_games, solved_games_df)
|
@@ -54,13 +66,15 @@ def declare_components(demo, greet):
|
|
54 |
).then(
|
55 |
lambda: gr.update(interactive=False), None, [new_game_btn],
|
56 |
).then(
|
57 |
-
check_played_game, [solved_games,
|
58 |
).then(
|
59 |
-
lambda: gr.update(interactive=True)
|
|
|
|
|
60 |
)
|
61 |
|
62 |
return (
|
63 |
-
(m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle),
|
64 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
65 |
)
|
66 |
|
@@ -489,7 +503,8 @@ def _is_checksum_same(fp_out, matches=None, mime_type="application/octet-stream"
|
|
489 |
matches = _files.list(
|
490 |
q=f"'{_folder_id}' in parents and mimeType='{mime_type}' and name = '{fp_out.rsplit('/', 1)[-1]}'",
|
491 |
fields=f"files(name, id, {_cksm_methods_str})",
|
492 |
-
).execute()
|
|
|
493 |
if not os.path.exists(fp_out):
|
494 |
return None, None, matches
|
495 |
with open(fp_out, "rb") as o:
|
@@ -502,9 +517,9 @@ def _is_checksum_same(fp_out, matches=None, mime_type="application/octet-stream"
|
|
502 |
|
503 |
|
504 |
# %%
|
505 |
-
def upload_to_drive(fp_out, matches=None, mime_type="application/octet-stream", compare_checksum=True):
|
506 |
if compare_checksum:
|
507 |
-
same_checksum, _,
|
508 |
# same_checksum, _, _ = _is_checksum_same(
|
509 |
# fp_out, **{k: v for k, v in [('matches', matches), ('mime_type', mime_type)] if v})
|
510 |
if same_checksum:
|
@@ -513,7 +528,11 @@ def upload_to_drive(fp_out, matches=None, mime_type="application/octet-stream",
|
|
513 |
file_metadata = {"name": fn, "parents": [_folder_id]}
|
514 |
media = MediaFileUpload(fp_out)
|
515 |
try:
|
516 |
-
|
|
|
|
|
|
|
|
|
517 |
except HttpError as error:
|
518 |
msg = f"Failed to upload the file, error: {error}"
|
519 |
print(msg)
|
@@ -547,7 +566,7 @@ def download_from_drive(fp_out, matches=None, mime_type="application/octet-strea
|
|
547 |
|
548 |
# %%
|
549 |
def start_new_game(game_name, level, session_state_component, is_solved_component, solved_games_component,
|
550 |
-
user=None, show_timer=False, uid=None):
|
551 |
# cur_game_id = GAME_IDS[GAME_NAMES.index(game_name)]
|
552 |
difficulty_level = LEVEL_IDS[LEVELS.index(level)]
|
553 |
|
@@ -555,11 +574,16 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
|
|
555 |
# elapsed_text = gr.Textbox("N/A", label=f"{game_name}", info=f"{level}", )
|
556 |
# gr.Timer(.3).tick(_calc_time_elapsed, [cur_game_start, elapsed_text, is_solved_component], [elapsed_text])
|
557 |
|
558 |
-
|
|
|
|
|
|
|
559 |
cur_game = (
|
560 |
new_game(game_name, difficulty_level)
|
561 |
if user is None else
|
562 |
preload_game(game_name, difficulty_level, user)
|
|
|
|
|
563 |
)
|
564 |
cur_game.attach_stats_output_(fp_out)
|
565 |
cur_game.flush_stats_(user_info_to_flush=user)
|
@@ -616,8 +640,12 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
|
|
616 |
js=js_submit)
|
617 |
give_up_checkbox = gr.Checkbox(False, visible=False, interactive=False)
|
618 |
give_up_btn.click(
|
|
|
|
|
619 |
lambda x: x, [give_up_checkbox], [give_up_checkbox],
|
620 |
js="(x) => confirm('🥹 Give-up? 💸')"
|
|
|
|
|
621 |
)
|
622 |
|
623 |
def _forfeiting(confirmed, _solved_games):
|
@@ -640,6 +668,8 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
|
|
640 |
def game_is_solved(_is_solved, _session_state, _solved_games, progress=gr.Progress()):
|
641 |
if _is_solved:
|
642 |
if level in LEVELS and level not in _solved_games[game_name]:
|
|
|
|
|
643 |
_solved_games[game_name].append(level)
|
644 |
return (
|
645 |
2,
|
@@ -655,8 +685,16 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
|
|
655 |
|
656 |
def finalize_game(_is_solved):
|
657 |
if _is_solved:
|
658 |
-
gr.Info("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
659 |
upload_to_drive(fp_out)
|
|
|
660 |
return gr.update(interactive=True)
|
661 |
return gr.update()
|
662 |
|
@@ -673,13 +711,14 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
|
|
673 |
|
674 |
|
675 |
# %%
|
676 |
-
def check_to_start_new_game(game_name, level, user=None, uid=None):
|
677 |
-
print(game_name, level)
|
678 |
if game_name is None or level is None:
|
679 |
raise gr.Error("please choose both Game & Level")
|
680 |
-
fp = _get_file_output(game_name, LEVEL_IDS[LEVELS.index(level)], uid)
|
681 |
if os.path.exists(fp):
|
682 |
-
raise gr.Error(f"You have done this game already.<br/>{game_name} - {level}")
|
|
|
683 |
if user is None:
|
684 |
gr.Warning("no user, game will be generated randomly")
|
685 |
# else:
|
@@ -691,16 +730,19 @@ def check_to_start_new_game(game_name, level, user=None, uid=None):
|
|
691 |
|
692 |
|
693 |
# %%
|
694 |
-
def check_played_game(solved_games,
|
|
|
|
|
695 |
matches = _files.list(
|
696 |
q=f"'{_folder_id}' in parents and mimeType='application/octet-stream' and name contains '{uid}_-_'",
|
697 |
fields=f"files(name, id, {_cksm_methods_str})",
|
698 |
-
).execute()
|
|
|
699 |
ret = dict()
|
700 |
for game_name in solved_games.keys():
|
701 |
cur = []
|
702 |
for level, level_id in zip(LEVELS, LEVEL_IDS):
|
703 |
-
fp_out = _get_file_output(game_name, level_id, uid)
|
704 |
_matches = list(filter(lambda m: fp_out.endswith(m['name']), matches))
|
705 |
if os.path.exists(fp_out):
|
706 |
upload_to_drive(fp_out, _matches)
|
@@ -708,7 +750,7 @@ def check_played_game(solved_games, uid, progress=gr.Progress()):
|
|
708 |
download_from_drive(fp_out, _matches)
|
709 |
if os.path.exists(fp_out):
|
710 |
cur.append(level)
|
711 |
-
ret[game_name] = cur
|
712 |
return ret, gr.update()
|
713 |
|
714 |
|
|
|
1 |
# %%
|
2 |
import os
|
3 |
import time
|
4 |
+
import json
|
5 |
import pandas as pd
|
6 |
import gradio as gr
|
7 |
import hashlib
|
|
|
20 |
from googleapiclient.errors import HttpError
|
21 |
from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload
|
22 |
|
23 |
+
# %%
|
24 |
+
_leaderboards = f"{os.getenv('TEXTGAMES_OUTPUT_DIR', '.')}/_leaderboards.jsonl"
|
25 |
+
|
26 |
|
27 |
# %%
|
28 |
+
def declare_components(demo, greet, use_login_button=False):
|
29 |
with gr.Row():
|
30 |
with gr.Column(scale=1):
|
31 |
m = gr.Markdown("Welcome to TextGames!", elem_id="md-greeting")
|
32 |
+
if use_login_button:
|
33 |
+
logout_btn = gr.LoginButton(size='sm')
|
34 |
+
reset_sid_btn = gr.Button("♻️ Reset Game Progress", variant='huggingface', size='sm')
|
35 |
+
else:
|
36 |
+
logout_btn = gr.Button("Logout", link="/logout", variant='huggingface', size='sm', elem_id="btn-logout")
|
37 |
+
reset_sid_btn = gr.Button(interactive=False, visible=False, size='sm')
|
38 |
with gr.Column(scale=2):
|
39 |
+
solved_games_df = gr.DataFrame(headers=[g.split('\t', 1)[0] for g in GAME_NAMES], label="Attempted Games",
|
40 |
+
row_count=(1, 'fixed'), interactive=False, elem_id="df-solved-games")
|
41 |
+
level_radio = gr.Radio(LEVELS, label="Level", elem_id="radio-level-name", visible=False)
|
42 |
+
game_radio = gr.Radio(GAME_NAMES, label="Game", elem_id="radio-game-name", visible=False)
|
43 |
+
new_game_btn = gr.Button("Start Game", elem_id="btn-start-game", visible=False)
|
44 |
render_toggle = gr.Checkbox(False, visible=False, interactive=False)
|
45 |
|
46 |
# cur_game_start = gr.BrowserState()
|
|
|
50 |
user_state = gr.State()
|
51 |
uid_state = gr.State()
|
52 |
|
53 |
+
if not os.path.exists(_leaderboards):
|
54 |
+
download_from_drive(_leaderboards, compare_checksum=False)
|
55 |
+
|
56 |
session_state.change(
|
57 |
+
lambda s: session_state_change_fn(s, 2, 0, 3, 0),
|
58 |
+
[session_state], [game_radio, level_radio, new_game_btn, logout_btn, reset_sid_btn], js=js_remove_input_helper,
|
59 |
)
|
60 |
new_game_btn.click(check_to_start_new_game, [game_radio, level_radio, user_state, uid_state], [session_state])
|
61 |
solved_games.change(solved_games_change_fn, solved_games, solved_games_df)
|
|
|
66 |
).then(
|
67 |
lambda: gr.update(interactive=False), None, [new_game_btn],
|
68 |
).then(
|
69 |
+
check_played_game, [solved_games, user_state], [solved_games, solved_games_df]
|
70 |
).then(
|
71 |
+
lambda uid: ([gr.update(visible=True, interactive=True)] if uid else
|
72 |
+
[gr.update(visible=False, interactive=False)]) * 3,
|
73 |
+
[uid_state], [level_radio, game_radio, new_game_btn]
|
74 |
)
|
75 |
|
76 |
return (
|
77 |
+
(m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
|
78 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
79 |
)
|
80 |
|
|
|
503 |
matches = _files.list(
|
504 |
q=f"'{_folder_id}' in parents and mimeType='{mime_type}' and name = '{fp_out.rsplit('/', 1)[-1]}'",
|
505 |
fields=f"files(name, id, {_cksm_methods_str})",
|
506 |
+
).execute()
|
507 |
+
matches = matches['files']
|
508 |
if not os.path.exists(fp_out):
|
509 |
return None, None, matches
|
510 |
with open(fp_out, "rb") as o:
|
|
|
517 |
|
518 |
|
519 |
# %%
|
520 |
+
def upload_to_drive(fp_out, matches=None, mime_type="application/octet-stream", compare_checksum=True, update=False):
|
521 |
if compare_checksum:
|
522 |
+
same_checksum, _, matches = _is_checksum_same(fp_out, matches, mime_type)
|
523 |
# same_checksum, _, _ = _is_checksum_same(
|
524 |
# fp_out, **{k: v for k, v in [('matches', matches), ('mime_type', mime_type)] if v})
|
525 |
if same_checksum:
|
|
|
528 |
file_metadata = {"name": fn, "parents": [_folder_id]}
|
529 |
media = MediaFileUpload(fp_out)
|
530 |
try:
|
531 |
+
if update and matches:
|
532 |
+
file_metadata.pop("parents")
|
533 |
+
_files.update(fileId=matches[0]['id'], body=file_metadata, media_body=media).execute()
|
534 |
+
else:
|
535 |
+
_files.create(body=file_metadata, media_body=media).execute()
|
536 |
except HttpError as error:
|
537 |
msg = f"Failed to upload the file, error: {error}"
|
538 |
print(msg)
|
|
|
566 |
|
567 |
# %%
|
568 |
def start_new_game(game_name, level, session_state_component, is_solved_component, solved_games_component,
|
569 |
+
user=None, show_timer=False, uid=None, sid=None):
|
570 |
# cur_game_id = GAME_IDS[GAME_NAMES.index(game_name)]
|
571 |
difficulty_level = LEVEL_IDS[LEVELS.index(level)]
|
572 |
|
|
|
574 |
# elapsed_text = gr.Textbox("N/A", label=f"{game_name}", info=f"{level}", )
|
575 |
# gr.Timer(.3).tick(_calc_time_elapsed, [cur_game_start, elapsed_text, is_solved_component], [elapsed_text])
|
576 |
|
577 |
+
if (not sid) and user and ('sid' in user):
|
578 |
+
sid = user['sid']
|
579 |
+
|
580 |
+
fp_out = _get_file_output(game_name, difficulty_level, f"{uid}_{sid}")
|
581 |
cur_game = (
|
582 |
new_game(game_name, difficulty_level)
|
583 |
if user is None else
|
584 |
preload_game(game_name, difficulty_level, user)
|
585 |
+
if sid is None else
|
586 |
+
preload_game(game_name, difficulty_level, user, sid=sid)
|
587 |
)
|
588 |
cur_game.attach_stats_output_(fp_out)
|
589 |
cur_game.flush_stats_(user_info_to_flush=user)
|
|
|
640 |
js=js_submit)
|
641 |
give_up_checkbox = gr.Checkbox(False, visible=False, interactive=False)
|
642 |
give_up_btn.click(
|
643 |
+
lambda: (gr.update(interactive=False), gr.update(interactive=False)), None, [submit_btn, give_up_btn]
|
644 |
+
).then(
|
645 |
lambda x: x, [give_up_checkbox], [give_up_checkbox],
|
646 |
js="(x) => confirm('🥹 Give-up? 💸')"
|
647 |
+
).then(
|
648 |
+
lambda: (gr.update(interactive=True), gr.update(interactive=True)), None, [submit_btn, give_up_btn]
|
649 |
)
|
650 |
|
651 |
def _forfeiting(confirmed, _solved_games):
|
|
|
668 |
def game_is_solved(_is_solved, _session_state, _solved_games, progress=gr.Progress()):
|
669 |
if _is_solved:
|
670 |
if level in LEVELS and level not in _solved_games[game_name]:
|
671 |
+
if isinstance(_solved_games[game_name], str):
|
672 |
+
_solved_games[game_name] = []
|
673 |
_solved_games[game_name].append(level)
|
674 |
return (
|
675 |
2,
|
|
|
685 |
|
686 |
def finalize_game(_is_solved):
|
687 |
if _is_solved:
|
688 |
+
gr.Info(f"Wrapping things up... Please click the button when available...<br/>"
|
689 |
+
f"Time: {cur_game.end_timestamp-cur_game.start_timestamp:4.1f} sec. Attempt: {cur_game.attempt_count}.")
|
690 |
+
with open(_leaderboards, "a", encoding="utf-8") as f:
|
691 |
+
json.dump({'uid': uid, 'sid': sid, 'turns': cur_game.attempt_count,
|
692 |
+
'st': cur_game.start_timestamp, 'ed': cur_game.end_timestamp,
|
693 |
+
'game_name': game_name, 'difficulty_level': difficulty_level,
|
694 |
+
}, f)
|
695 |
+
f.write("\n")
|
696 |
upload_to_drive(fp_out)
|
697 |
+
upload_to_drive(_leaderboards, update=True)
|
698 |
return gr.update(interactive=True)
|
699 |
return gr.update()
|
700 |
|
|
|
711 |
|
712 |
|
713 |
# %%
|
714 |
+
def check_to_start_new_game(game_name, level, user=None, uid=None, sid=None):
|
715 |
+
print(game_name, level, uid, sid)
|
716 |
if game_name is None or level is None:
|
717 |
raise gr.Error("please choose both Game & Level")
|
718 |
+
fp = _get_file_output(game_name, LEVEL_IDS[LEVELS.index(level)], f"{uid}_{sid}")
|
719 |
if os.path.exists(fp):
|
720 |
+
# raise gr.Error(f"You have done this game already.<br/>{game_name} - {level}")
|
721 |
+
gr.Warning("You have done this game already. Only first attempt is recorded in the scoreboard.")
|
722 |
if user is None:
|
723 |
gr.Warning("no user, game will be generated randomly")
|
724 |
# else:
|
|
|
730 |
|
731 |
|
732 |
# %%
|
733 |
+
def check_played_game(solved_games, user, progress=gr.Progress()):
|
734 |
+
uid = user['email']
|
735 |
+
sid = user.get('sid', None)
|
736 |
matches = _files.list(
|
737 |
q=f"'{_folder_id}' in parents and mimeType='application/octet-stream' and name contains '{uid}_-_'",
|
738 |
fields=f"files(name, id, {_cksm_methods_str})",
|
739 |
+
).execute()
|
740 |
+
matches = matches['files']
|
741 |
ret = dict()
|
742 |
for game_name in solved_games.keys():
|
743 |
cur = []
|
744 |
for level, level_id in zip(LEVELS, LEVEL_IDS):
|
745 |
+
fp_out = _get_file_output(game_name, level_id, f"{uid}_{sid}")
|
746 |
_matches = list(filter(lambda m: fp_out.endswith(m['name']), matches))
|
747 |
if os.path.exists(fp_out):
|
748 |
upload_to_drive(fp_out, _matches)
|
|
|
750 |
download_from_drive(fp_out, _matches)
|
751 |
if os.path.exists(fp_out):
|
752 |
cur.append(level)
|
753 |
+
ret[game_name] = cur or '∅'
|
754 |
return ret, gr.update()
|
755 |
|
756 |
|
play_with_auth.py
CHANGED
@@ -130,7 +130,7 @@ with gr.Blocks(title="TextGames") as login_demo:
|
|
130 |
app = gr.mount_gradio_app(app, login_demo, path="/login")
|
131 |
|
132 |
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
|
133 |
-
((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle),
|
134 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
135 |
) = declare_components(demo, greet)
|
136 |
|
|
|
130 |
app = gr.mount_gradio_app(app, login_demo, path="/login")
|
131 |
|
132 |
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
|
133 |
+
((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
|
134 |
(session_state, is_solved, solved_games, user_state, uid_state),
|
135 |
) = declare_components(demo, greet)
|
136 |
|
play_with_hf.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
#%%
|
4 |
+
import os
|
5 |
+
# os.environ.setdefault("GRADIO_SERVER_PORT", "1080")
|
6 |
+
# os.environ.setdefault("TEXTGAMES_SHOW_HIDDEN_LEVEL", "1")
|
7 |
+
os.environ.setdefault("TEXTGAMES_LOADGAME_DIR", "problemsets")
|
8 |
+
os.environ.setdefault("TEXTGAMES_LOADGAME_ID", "42")
|
9 |
+
os.environ.setdefault("TEXTGAMES_MOCKUSER", "")
|
10 |
+
os.environ.setdefault("TEXTGAMES_OUTPUT_DIR", "user_outputs")
|
11 |
+
favicon_path = "textgames-scrabble-black2-ss.png"
|
12 |
+
|
13 |
+
#%%
|
14 |
+
from play_helper import css, declare_components, start_new_game, check_played_game, download_from_drive, upload_to_drive, _leaderboards
|
15 |
+
import pandas as pd
|
16 |
+
import gradio as gr
|
17 |
+
import random
|
18 |
+
import json
|
19 |
+
from textgames import GAME_NAMES
|
20 |
+
|
21 |
+
|
22 |
+
#%%
|
23 |
+
os.makedirs(os.getenv('TEXTGAMES_OUTPUT_DIR', '.'), exist_ok=True)
|
24 |
+
|
25 |
+
|
26 |
+
#%%
|
27 |
+
def generate_sid(fp):
    """Write a fresh session id to ``fp`` and push the file to Drive.

    The id has the form ``session_NNNN`` with ``NNNN`` drawn from 0..1000.
    The file is overwritten, so it always holds exactly one id line.

    If ``fp`` already holds an id, the new one is guaranteed to differ from it;
    otherwise a "reset" could silently keep the user on the same session.

    Args:
        fp: Local path of the per-user sid file.
    """
    previous = None
    if os.path.exists(fp):
        with open(fp, "r", encoding="utf8") as f:
            known = f.read().split()
        previous = known[-1] if known else None

    # Redraw until the id changes (runs exactly once when there is no previous id).
    sid = previous
    while sid == previous:
        sid = f"session_{random.randint(0, 1000):04}"

    with open(fp, "w", encoding="utf8") as f:
        f.write(f"{sid}\n")
    upload_to_drive(fp, mime_type="text/plain", update=True)
|
32 |
+
|
33 |
+
|
34 |
+
#%%
|
35 |
+
def get_sid(uid, force_generate_sid=False):
    """Return the current session id of user ``uid``.

    The id is stored in ``<TEXTGAMES_OUTPUT_DIR>/<uid>_sid.txt``. Lookup order:
    the local file, then a download from Drive, then a newly generated id.

    Args:
        uid: User identifier used to build the sid file name.
        force_generate_sid: When true, a new id is generated first, replacing
            whatever was stored before.

    Returns:
        The last non-empty line of the sid file, stripped of whitespace.
    """
    # Same default as the rest of the project, so an unset variable does not
    # turn the path into "None/<uid>_sid.txt".
    out_dir = os.getenv('TEXTGAMES_OUTPUT_DIR', '.')
    fp = f"{out_dir}/{uid}_sid.txt"

    def _stored_sids():
        # Non-empty lines of the sid file, or [] when the file is missing.
        if not os.path.exists(fp):
            return []
        with open(fp, "r", encoding="utf8") as f:
            return [ln.strip() for ln in f if ln.strip()]

    if force_generate_sid:
        generate_sid(fp)
    if not os.path.exists(fp):
        download_from_drive(fp, mime_type="text/plain", compare_checksum=False)

    sids = _stored_sids()
    if not sids:
        # Missing or empty file (e.g. nothing on Drive yet): start a new session.
        generate_sid(fp)
        sids = _stored_sids()
    return sids[-1]
|
46 |
+
|
47 |
+
|
48 |
+
#%%
|
49 |
+
def greet(request: gr.OAuthProfile | None):
    """Build the greeting text and user record for the current visitor.

    Args:
        request: The OAuth profile of the logged-in visitor, or ``None`` when
            nobody is logged in.

    Returns:
        A 3-tuple ``(greeting_text, user_dict, user_email)``. Without a
        profile the user dict carries the ``TEXTGAMES_MOCKUSER`` value as
        email, an empty name and no ``'sid'`` key.
    """
    if request is None:
        user = {'email': os.getenv('TEXTGAMES_MOCKUSER', ''), 'name': ""}
    else:
        user = {'email': request.username, 'name': request.name, 'sid': get_sid(request.username)}
    display_name = user['name'] or 'please login'
    return f"""
    Welcome to TextGames, {display_name}!
    """, user, user['email']
|
56 |
+
|
57 |
+
|
58 |
+
#%%
|
59 |
+
with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
    ((m, logout_btn, solved_games_df, game_radio, level_radio, new_game_btn, render_toggle, reset_sid_btn),
     (session_state, is_solved, solved_games, user_state, uid_state),
     ) = declare_components(demo, greet, use_login_button=True)
    logout_btn.activate()

    def _set_interactive(flag, count):
        # Event handler that sets `interactive=flag` on `count` output components.
        return lambda: [gr.update(interactive=flag)] * count

    # Hidden checkbox that carries the result of the browser-side confirm() dialog.
    reset_sid_checkbox = gr.Checkbox(False, visible=False, interactive=False)

    # Click: lock the buttons, ask for confirmation in the browser, unlock again.
    ask_reset = reset_sid_btn.click(
        _set_interactive(False, 2), None, [reset_sid_btn, new_game_btn]
    )
    ask_reset = ask_reset.then(
        lambda x: x, [reset_sid_checkbox], [reset_sid_checkbox],
        js="(x) => confirm('Reset Progress? (cannot be undone)')"
    )
    ask_reset.then(
        _set_interactive(True, 2), None, [reset_sid_btn, new_game_btn]
    )

    def _resetting(confirmed, user):
        # Give the user a new session id when the reset was confirmed;
        # always hand back `False` so the hidden checkbox is cleared again.
        uid = user.get('email', None) if isinstance(user, dict) else None
        if uid is None:
            gr.Warning("You need to log in first!")
        elif confirmed:
            user['sid'] = get_sid(uid, force_generate_sid=True)
        return user, False

    # Checkbox change: lock the UI, apply the reset, refresh the attempted-games
    # table for the (possibly new) session, then unlock the UI.
    apply_reset = reset_sid_checkbox.change(
        _set_interactive(False, 3), None, [logout_btn, reset_sid_btn, new_game_btn]
    )
    apply_reset = apply_reset.then(
        _resetting, [reset_sid_checkbox, user_state], [user_state, reset_sid_checkbox]
    )
    apply_reset = apply_reset.then(
        check_played_game, [solved_games, user_state], [solved_games, solved_games_df]
    )
    apply_reset.then(
        _set_interactive(True, 3), None, [logout_btn, reset_sid_btn, new_game_btn]
    )

    @gr.render(inputs=[game_radio, level_radio, user_state, session_state, uid_state], triggers=[render_toggle.change])
    def _start_new_game(game_name, level, user, _session_state, _uid_state):
        # Only render the game UI while the session state is 1 or 2.
        if _session_state in [1, 2]:
            start_new_game(game_name, level, session_state, is_solved, solved_games, user=user, uid=_uid_state)
|
97 |
+
|
98 |
+
#%%
|
99 |
+
with demo.route("Leaderboards", "/leaderboard") as demo_leaderboard:
    gr.Markdown("Under Construction. Will be available soon.")

    # One rankings table per difficulty tab.
    leaderboards = []
    for tab in ["🚅 Easy", "🚀 Medium", "🛸 Hard"]:
        with gr.Tab(tab):
            leaderboards.append(gr.DataFrame(label="Rankings"))

    # if os.path.exists(_leaderboards):
    #     datas = []
    #     with open(_leaderboards, "r", encoding="utf8") as f:
    #         for line in f:
    #             datas.append(json.loads(line))
    #     concat = [{'Level': d['difficulty_level'], 'User': d['uid'], 'Game': d['game_name'].split('\t', 1)[0], 'Attempts': d['turns'],
    #                "Time": d['ed'] - d['st']} for d in datas]
    # else:
    def add_dummies():
        # Placeholder row shown until real leaderboard data is wired in.
        short_names = [g.split('\t', 1)[0] for g in GAME_NAMES]
        return pd.DataFrame({
            'User': ['dummy'],
            'Solved': [' '.join(short_names)],
            'Attempts': [8],
            'Time': [7200.8],
        })

    # Fill every table with the placeholder row when the page loads.
    for board in leaderboards:
        demo_leaderboard.load(add_dummies, None, [board])


#%%
# demo.launch()
# Only pass the favicon when the image file is actually present.
favicon = favicon_path if os.path.exists(favicon_path) else None
demo.launch(
    favicon_path=favicon,
    show_api=False,
)
|
131 |
+
|
132 |
+
|
problemsets/Anagram Scribble_1.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
problemsets/Anagram Scribble_2.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
problemsets/Anagram Scribble_3.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
problemsets/Bracket Game_1.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
problemsets/Bracket Game_2.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
problemsets/Bracket Game_3.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
problemsets/Crossword Arranger_1.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
problemsets/Crossword Arranger_2.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
problemsets/Crossword Arranger_3.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
reval_ana3.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from tqdm import tqdm
|
4 |
+
from textgames import GAME_NAMES, game_filename, _game_class_from_name
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
GAME_NAME = GAME_NAMES[5]
|
8 |
+
PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
|
9 |
+
MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
|
10 |
+
OUTPUT_FILENAMES = [
|
11 |
+
# "results_gemma-2-9b-it.1s.jsonl",
|
12 |
+
# "results_gemma-2-9b-it.zs.jsonl",
|
13 |
+
# "results_gemma-2-27b-it.1s.jsonl",
|
14 |
+
# "results_gemma-2-27b-it.zs.jsonl",
|
15 |
+
#
|
16 |
+
# "results_llama-3.1-8b-instruct.1s.jsonl",
|
17 |
+
# "results_llama-3.1-8b-instruct.zs.jsonl",
|
18 |
+
# "results_llama-3.1-70b-instruct.1s.jsonl",
|
19 |
+
# "results_llama-3.1-70b-instruct.zs.jsonl",
|
20 |
+
# "results_llama-3.3-70b-instruct.1s.jsonl",
|
21 |
+
# "results_llama-3.3-70b-instruct.zs.jsonl",
|
22 |
+
#
|
23 |
+
# "results_qwen2-5-7b-instruct.1s.jsonl",
|
24 |
+
# "results_qwen2-5-7b-instruct.zs.jsonl",
|
25 |
+
# "results_qwen2-5-14b-instruct.1s.jsonl",
|
26 |
+
# "results_qwen2-5-14b-instruct.zs.jsonl",
|
27 |
+
# "results_qwen2-5-32b-instruct.1s.jsonl",
|
28 |
+
# "results_qwen2-5-32b-instruct.zs.jsonl",
|
29 |
+
# "results_qwen2-5-72b-instruct.1s.jsonl",
|
30 |
+
# "results_qwen2-5-72b-instruct.zs.jsonl",
|
31 |
+
#
|
32 |
+
# "results_deepseek-r1-distill-14b.1s.jsonl",
|
33 |
+
# "results_deepseek-r1-distill-14b.zs.jsonl",
|
34 |
+
# "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
|
35 |
+
#
|
36 |
+
# "results_chatgpt-4o-mini.zs.jsonl",
|
37 |
+
# "results_chatgpt-o3-mini.zs.jsonl",
|
38 |
+
#
|
39 |
+
# "results_qwen2-5-7b-instruct_sp.1s.jsonl",
|
40 |
+
# "results_qwen2-5-7b-instruct_sp.zs.jsonl",
|
41 |
+
|
42 |
+
# "results_deepseek-r1-distill-8b.1s.jsonl",
|
43 |
+
"results_deepseek-r1-distill-8b.zs.jsonl",
|
44 |
+
]
|
45 |
+
|
46 |
+
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
47 |
+
# !!! Must run reval_bracket_rerun.py first !!!
|
48 |
+
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
49 |
+
|
50 |
+
|
51 |
+
def revalidate_anagram_3(fp, reval_dir="revalidate_anagram_3", source_dir="prior_revalidate"):
    """Re-run validation for the level-3 records of ``GAME_NAME`` in one result file.

    Reads ``MODEL_OUTPUT_DIR/source_dir/fp`` (JSON lines) and writes
    ``MODEL_OUTPUT_DIR/reval_dir/fp``. Records of other games are copied
    unchanged. For the targeted game each session is replayed: a record with
    ``turn == 1`` starts a new game loaded from ``sid_prompt_dict``, every
    response is validated again and ``res["solved"]`` is overwritten. Records
    that follow an already solved turn of the same session are dropped.

    Relies on the module globals ``sid_prompt_dict`` and ``game_cls`` that the
    ``__main__`` section sets up.

    Args:
        fp: File name of the result file.
        reval_dir: Output sub-directory (created if missing).
        source_dir: Input sub-directory.

    Returns:
        ``(count_pos, count_neg)``: records flipped to solved / to unsolved.

    Raises:
        ValueError: If a follow-up turn does not belong to the session that is
            currently being replayed (including a file that does not start a
            session with turn 1).
    """
    os.makedirs(MODEL_OUTPUT_DIR/reval_dir, exist_ok=True)
    count_pos, count_neg = 0, 0
    target_game = f"{game_filename(GAME_NAME)}_3"  # loop invariant
    # Replay state; initialised so a malformed file fails with a clear error
    # instead of an UnboundLocalError.
    cur_sid, cur_game, solved = None, None, False
    with (open(MODEL_OUTPUT_DIR/source_dir/fp, "r", encoding="utf8") as i,
          open(MODEL_OUTPUT_DIR/reval_dir/fp, "w", encoding="utf8") as o,
          tqdm(total=1000, desc=fp) as pbar,
          ):
        for line in i:
            res = json.loads(line)
            if res['game'] == target_game:
                if res['turn'] == 1:
                    # New session: rebuild the game from its stored prompt.
                    cur_sid = res["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dict[cur_sid])
                    pbar.update(1)
                elif solved:
                    # Session was already solved on an earlier turn: drop the rest.
                    continue
                elif cur_sid != res["session"]:
                    raise ValueError(
                        f"{fp}: turn {res['turn']} of session {res['session']!r} "
                        f"does not continue the current session {cur_sid!r}"
                    )
                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved
            o.write(json.dumps(res))
            o.write("\n")
    return count_pos, count_neg
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
    # Both names below are module globals read by revalidate_anagram_3().
    game_cls = _game_class_from_name(GAME_NAME)
    problemset_path = f"{PROBLEMSET_DIR}/{game_filename(GAME_NAME)}_3.json"
    with open(problemset_path, "r", encoding="utf8") as f:
        sid_prompt_dict = json.load(f)
    # Print the (flipped-to-solved, flipped-to-unsolved) counts per result file.
    for fp in OUTPUT_FILENAMES:
        counts = revalidate_anagram_3(fp)
        print(counts)
|
reval_bracket_all.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from tqdm import tqdm
|
4 |
+
from textgames import GAME_NAMES, game_filename, _game_class_from_name
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
GAME_NAME = GAME_NAMES[6]
|
8 |
+
PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
|
9 |
+
MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
|
10 |
+
OUTPUT_FILENAMES = [
|
11 |
+
# "results_gemma-2-9b-it.1s.jsonl",
|
12 |
+
# "results_gemma-2-9b-it.zs.jsonl",
|
13 |
+
# "results_gemma-2-27b-it.1s.jsonl",
|
14 |
+
# "results_gemma-2-27b-it.zs.jsonl",
|
15 |
+
#
|
16 |
+
# "results_llama-3.1-8b-instruct.1s.jsonl",
|
17 |
+
# "results_llama-3.1-8b-instruct.zs.jsonl",
|
18 |
+
# "results_llama-3.1-70b-instruct.1s.jsonl",
|
19 |
+
# "results_llama-3.1-70b-instruct.zs.jsonl",
|
20 |
+
# "results_llama-3.3-70b-instruct.1s.jsonl",
|
21 |
+
# "results_llama-3.3-70b-instruct.zs.jsonl",
|
22 |
+
#
|
23 |
+
# "results_qwen2-5-7b-instruct.1s.jsonl",
|
24 |
+
# "results_qwen2-5-7b-instruct.zs.jsonl",
|
25 |
+
# "results_qwen2-5-14b-instruct.1s.jsonl",
|
26 |
+
# "results_qwen2-5-14b-instruct.zs.jsonl",
|
27 |
+
# "results_qwen2-5-32b-instruct.1s.jsonl",
|
28 |
+
# "results_qwen2-5-32b-instruct.zs.jsonl",
|
29 |
+
# "results_qwen2-5-72b-instruct.1s.jsonl",
|
30 |
+
# "results_qwen2-5-72b-instruct.zs.jsonl",
|
31 |
+
#
|
32 |
+
# "results_deepseek-r1-distill-14b.1s.jsonl",
|
33 |
+
# "results_deepseek-r1-distill-14b.zs.jsonl",
|
34 |
+
# "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
|
35 |
+
#
|
36 |
+
# "results_chatgpt-4o-mini.1s.jsonl",
|
37 |
+
# "results_chatgpt-4o-mini.zs.jsonl",
|
38 |
+
# "results_chatgpt-o3-mini.zs.jsonl",
|
39 |
+
#
|
40 |
+
# "results_qwen2-5-7b-instruct_sp.1s.jsonl",
|
41 |
+
# "results_qwen2-5-7b-instruct_sp.zs.jsonl",
|
42 |
+
|
43 |
+
# "results_deepseek-r1-distill-8b.1s.jsonl",
|
44 |
+
"results_deepseek-r1-distill-8b.zs.jsonl",
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
+
def revalidate_bracket(fp, reval_dir="revalidate_bracket_all",
                       source_dirs=("revalidate_bracket_rerun", "revalidate_anagram_3",)):
    """Re-run validation for every ``GAME_NAME`` record (all levels) of one result file.

    The input is taken from the first directory in ``source_dirs`` that
    contains ``fp``; if none does, ``MODEL_OUTPUT_DIR/fp`` itself is used.
    Records of other games are copied unchanged to
    ``MODEL_OUTPUT_DIR/reval_dir/fp``. For the targeted game each session is
    replayed: a record with ``turn == 1`` starts a new game loaded from the
    problem set of its level, every response is validated again and
    ``res["solved"]`` is overwritten. Records that follow an already solved
    turn of the same session are dropped.

    Relies on the module globals ``sid_prompt_dicts`` and ``game_cls`` that the
    ``__main__`` section sets up.

    Args:
        fp: File name of the result file.
        reval_dir: Output sub-directory (created if missing).
        source_dirs: Candidate input sub-directories, in priority order.

    Returns:
        ``(count_pos, count_neg)``: records flipped to solved / to unsolved.

    Raises:
        ValueError: If a follow-up turn does not belong to the session that is
            currently being replayed.
    """
    os.makedirs(MODEL_OUTPUT_DIR/reval_dir, exist_ok=True)
    count_pos, count_neg = 0, 0
    for candidate in source_dirs:
        if (MODEL_OUTPUT_DIR / candidate / fp).exists():
            source_dir = candidate
            break
    else:
        # No candidate holds the file: fall back to the output root itself.
        source_dir = "."
    game_prefix = f"{game_filename(GAME_NAME)}"  # loop invariant
    # Replay state; initialised so a malformed file fails with a clear error
    # instead of an UnboundLocalError.
    cur_sid, cur_game, solved = None, None, False
    with (open(MODEL_OUTPUT_DIR / source_dir / fp, "r", encoding="utf8") as i,
          open(MODEL_OUTPUT_DIR / reval_dir / fp, "w", encoding="utf8") as o,
          tqdm(total=3000, desc=fp) as pbar,
          ):
        for line in i:
            res = json.loads(line)
            if res['game'].startswith(game_prefix):
                # The level is the suffix after the last "_" of the game id.
                sid_prompt_dict = sid_prompt_dicts[res['game'].rsplit("_", 1)[-1]]
                if res['turn'] == 1:
                    # New session: rebuild the game from its stored prompt.
                    cur_sid = res["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dict[cur_sid])
                    pbar.update(1)
                elif solved:
                    # Session was already solved on an earlier turn: drop the rest.
                    continue
                elif cur_sid != res["session"]:
                    raise ValueError(
                        f"{fp}: turn {res['turn']} of session {res['session']!r} "
                        f"does not continue the current session {cur_sid!r}"
                    )
                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved
            o.write(json.dumps(res))
            o.write("\n")
    return count_pos, count_neg
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
    def load(k):
        # Read the problem set (session id -> prompt) of difficulty level `k`.
        problemset_path = f"{PROBLEMSET_DIR}/{game_filename(GAME_NAME)}_{k}.json"
        with open(problemset_path, "r", encoding="utf8") as f:
            return json.load(f)

    # Both names below are module globals read by revalidate_bracket().
    sid_prompt_dicts = {}
    for k in ("1", "2", "3"):
        sid_prompt_dicts[k] = load(k)
    game_cls = _game_class_from_name(GAME_NAME)
    # Print the (flipped-to-solved, flipped-to-unsolved) counts per result file.
    for fp in OUTPUT_FILENAMES:
        counts = revalidate_bracket(fp)
        print(counts)
|
94 |
+
|
reval_bracket_rerun.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @title ##### Combine Rerun of the Bracket - All
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
from tqdm import tqdm
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
|
8 |
+
fd_new = MODEL_OUTPUT_DIR / "__runs__" / "_redo_bracket"
|
9 |
+
fd_ori = MODEL_OUTPUT_DIR / "revalidate_anagram_3"
|
10 |
+
fd_out = MODEL_OUTPUT_DIR / "revalidate_bracket_rerun"
|
11 |
+
|
12 |
+
OUTPUT_FILENAMES = [
|
13 |
+
"results_gemma-2-9b-it.1s.jsonl",
|
14 |
+
"results_gemma-2-9b-it.zs.jsonl",
|
15 |
+
"results_gemma-2-27b-it.1s.jsonl",
|
16 |
+
"results_gemma-2-27b-it.zs.jsonl",
|
17 |
+
|
18 |
+
"results_llama-3.1-8b-instruct.1s.jsonl",
|
19 |
+
"results_llama-3.1-8b-instruct.zs.jsonl",
|
20 |
+
"results_llama-3.1-70b-instruct.1s.jsonl",
|
21 |
+
"results_llama-3.1-70b-instruct.zs.jsonl",
|
22 |
+
"results_llama-3.3-70b-instruct.1s.jsonl",
|
23 |
+
"results_llama-3.3-70b-instruct.zs.jsonl",
|
24 |
+
|
25 |
+
"results_qwen2-5-7b-instruct.1s.jsonl",
|
26 |
+
"results_qwen2-5-7b-instruct.zs.jsonl",
|
27 |
+
"results_qwen2-5-14b-instruct.1s.jsonl",
|
28 |
+
"results_qwen2-5-14b-instruct.zs.jsonl",
|
29 |
+
"results_qwen2-5-32b-instruct.1s.jsonl",
|
30 |
+
"results_qwen2-5-32b-instruct.zs.jsonl",
|
31 |
+
"results_qwen2-5-72b-instruct.1s.jsonl",
|
32 |
+
"results_qwen2-5-72b-instruct.zs.jsonl",
|
33 |
+
]
|
34 |
+
|
35 |
+
os.makedirs(fd_out, exist_ok=True)
for fp in tqdm(OUTPUT_FILENAMES):
    rerun_path = (fd_new / fp).with_suffix(".6.jsonl")
    with open(fd_out / fp, "w", encoding="utf8") as o:
        # 1) Copy every original record except the "Bracket Game" ones.
        with open(fd_ori / fp, "r", encoding="utf8") as i:
            o.writelines(
                line for line in i
                if not json.loads(line)['game'].startswith("Bracket Game")
            )
        # 2) Append the re-run records verbatim.
        with open(rerun_path, "r", encoding="utf8") as i:
            o.writelines(i)
|
reval_crosswords_all.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from tqdm import tqdm
|
4 |
+
from textgames import GAME_NAMES, game_filename, _game_class_from_name
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
GAME_NAME = GAME_NAMES[0]
|
8 |
+
PROBLEMSET_DIR = Path(os.getenv("TG_PROBLEMSET_DIR", "problemsets"))
|
9 |
+
MODEL_OUTPUT_DIR = Path(os.getenv("TG_MODEL_OUTPUT_DIR", "model_outputs"))
|
10 |
+
OUTPUT_FILENAMES = [
|
11 |
+
# "results_gemma-2-9b-it.1s.jsonl",
|
12 |
+
# "results_gemma-2-9b-it.zs.jsonl",
|
13 |
+
# "results_gemma-2-27b-it.1s.jsonl",
|
14 |
+
# "results_gemma-2-27b-it.zs.jsonl",
|
15 |
+
#
|
16 |
+
# "results_llama-3.1-8b-instruct.1s.jsonl",
|
17 |
+
# "results_llama-3.1-8b-instruct.zs.jsonl",
|
18 |
+
# "results_llama-3.1-70b-instruct.1s.jsonl",
|
19 |
+
# "results_llama-3.1-70b-instruct.zs.jsonl",
|
20 |
+
# "results_llama-3.3-70b-instruct.1s.jsonl",
|
21 |
+
# "results_llama-3.3-70b-instruct.zs.jsonl",
|
22 |
+
#
|
23 |
+
# "results_qwen2-5-7b-instruct.1s.jsonl",
|
24 |
+
# "results_qwen2-5-7b-instruct.zs.jsonl",
|
25 |
+
# "results_qwen2-5-14b-instruct.1s.jsonl",
|
26 |
+
# "results_qwen2-5-14b-instruct.zs.jsonl",
|
27 |
+
# "results_qwen2-5-32b-instruct.1s.jsonl",
|
28 |
+
# "results_qwen2-5-32b-instruct.zs.jsonl",
|
29 |
+
# "results_qwen2-5-72b-instruct.1s.jsonl",
|
30 |
+
# "results_qwen2-5-72b-instruct.zs.jsonl",
|
31 |
+
#
|
32 |
+
# "results_deepseek-r1-distill-14b.1s.jsonl",
|
33 |
+
# "results_deepseek-r1-distill-14b.zs.jsonl",
|
34 |
+
# # "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
|
35 |
+
#
|
36 |
+
# "results_chatgpt-4o-mini.1s.jsonl",
|
37 |
+
# "results_chatgpt-4o-mini.zs.jsonl",
|
38 |
+
# "results_chatgpt-o3-mini.zs.jsonl",
|
39 |
+
#
|
40 |
+
# # "results_qwen2-5-7b-instruct_sp.1s.jsonl",
|
41 |
+
# # "results_qwen2-5-7b-instruct_sp.zs.jsonl",
|
42 |
+
|
43 |
+
"results_deepseek-r1-distill-8b.1s.jsonl",
|
44 |
+
"results_deepseek-r1-distill-8b.zs.jsonl",
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
+
def revalidate_bracket(fp, reval_dir="revalidate_crosswords_all",
                       source_dirs=("revalidate_bracket_all",)):
    """Re-run validation for every ``GAME_NAME`` record (all levels) of one result file.

    The input is taken from the first directory in ``source_dirs`` that
    contains ``fp``; if none does, ``MODEL_OUTPUT_DIR/fp`` itself is used.
    Records of other games are copied unchanged to
    ``MODEL_OUTPUT_DIR/reval_dir/fp``. For the targeted game each session is
    replayed: a record with ``turn == 1`` starts a new game loaded from the
    problem set of its level, every response is validated again and
    ``res["solved"]`` is overwritten. Records that follow an already solved
    turn of the same session are dropped.

    Relies on the module globals ``sid_prompt_dicts`` and ``game_cls`` that the
    ``__main__`` section sets up.

    Args:
        fp: File name of the result file.
        reval_dir: Output sub-directory (created if missing).
        source_dirs: Candidate input sub-directories, in priority order.

    Returns:
        ``(count_pos, count_neg)``: records flipped to solved / to unsolved.

    Raises:
        ValueError: If a follow-up turn does not belong to the session that is
            currently being replayed.
    """
    os.makedirs(MODEL_OUTPUT_DIR/reval_dir, exist_ok=True)
    count_pos, count_neg = 0, 0
    for candidate in source_dirs:
        if (MODEL_OUTPUT_DIR / candidate / fp).exists():
            source_dir = candidate
            break
    else:
        # No candidate holds the file: fall back to the output root itself.
        source_dir = "."
    game_prefix = f"{game_filename(GAME_NAME)}"  # loop invariant
    # Replay state; initialised so a malformed file fails with a clear error
    # instead of an UnboundLocalError.
    cur_sid, cur_game, solved = None, None, False
    with (open(MODEL_OUTPUT_DIR / source_dir / fp, "r", encoding="utf8") as i,
          open(MODEL_OUTPUT_DIR / reval_dir / fp, "w", encoding="utf8") as o,
          tqdm(total=3000, desc=fp) as pbar,
          ):
        for line in i:
            res = json.loads(line)
            if res['game'].startswith(game_prefix):
                # The level is the suffix after the last "_" of the game id.
                sid_prompt_dict = sid_prompt_dicts[res['game'].rsplit("_", 1)[-1]]
                if res['turn'] == 1:
                    # New session: rebuild the game from its stored prompt.
                    cur_sid = res["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dict[cur_sid])
                    pbar.update(1)
                elif solved:
                    # Session was already solved on an earlier turn: drop the rest.
                    continue
                elif cur_sid != res["session"]:
                    raise ValueError(
                        f"{fp}: turn {res['turn']} of session {res['session']!r} "
                        f"does not continue the current session {cur_sid!r}"
                    )
                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved
            o.write(json.dumps(res))
            o.write("\n")
    return count_pos, count_neg
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
    def load(k):
        # Read the problem set (session id -> prompt) of difficulty level `k`.
        problemset_path = f"{PROBLEMSET_DIR}/{game_filename(GAME_NAME)}_{k}.json"
        with open(problemset_path, "r", encoding="utf8") as f:
            return json.load(f)

    # Both names below are module globals read by revalidate_bracket().
    sid_prompt_dicts = {}
    for k in ("1", "2", "3"):
        sid_prompt_dicts[k] = load(k)
    game_cls = _game_class_from_name(GAME_NAME)
    # Print the (flipped-to-solved, flipped-to-unsolved) counts per result file.
    for fp in OUTPUT_FILENAMES:
        counts = revalidate_bracket(fp)
        print(counts)
|
94 |
+
|
reval_sudoku_all.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from tqdm import tqdm
|
4 |
+
from textgames import GAME_NAMES, game_filename, _game_class_from_name
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
# Game whose recorded answers this script re-validates.
# NOTE(review): index 1 presumably selects the sudoku game (going by this
# script's file name) -- confirm against textgames.GAME_NAMES.
GAME_NAME = GAME_NAMES[1]

# Input/output locations; each can be overridden through its environment variable.
PROBLEMSET_DIR = Path(os.environ.get("TG_PROBLEMSET_DIR", "problemsets"))
MODEL_OUTPUT_DIR = Path(os.environ.get("TG_MODEL_OUTPUT_DIR", "model_outputs"))

# Result files to process; commented-out entries are currently disabled.
OUTPUT_FILENAMES = [
    # "results_gemma-2-9b-it.1s.jsonl",
    # "results_gemma-2-9b-it.zs.jsonl",
    # "results_gemma-2-27b-it.1s.jsonl",
    # "results_gemma-2-27b-it.zs.jsonl",
    #
    # "results_llama-3.1-8b-instruct.1s.jsonl",
    # "results_llama-3.1-8b-instruct.zs.jsonl",
    # "results_llama-3.1-70b-instruct.1s.jsonl",
    # "results_llama-3.1-70b-instruct.zs.jsonl",
    # "results_llama-3.3-70b-instruct.1s.jsonl",
    # "results_llama-3.3-70b-instruct.zs.jsonl",
    #
    # "results_qwen2-5-7b-instruct.1s.jsonl",
    # "results_qwen2-5-7b-instruct.zs.jsonl",
    # "results_qwen2-5-14b-instruct.1s.jsonl",
    # "results_qwen2-5-14b-instruct.zs.jsonl",
    # "results_qwen2-5-32b-instruct.1s.jsonl",
    # "results_qwen2-5-32b-instruct.zs.jsonl",
    # "results_qwen2-5-72b-instruct.1s.jsonl",
    # "results_qwen2-5-72b-instruct.zs.jsonl",
    #
    # "results_deepseek-r1-distill-14b.1s.jsonl",
    # "results_deepseek-r1-distill-14b.zs.jsonl",
    # # "results_deepseek-r1-distill-14b.rerun.1s.jsonl",
    #
    # "results_chatgpt-4o-mini.1s.jsonl",
    # "results_chatgpt-4o-mini.zs.jsonl",
    # "results_chatgpt-o3-mini.zs.jsonl",
    #
    # # "results_qwen2-5-7b-instruct_sp.1s.jsonl",
    # # "results_qwen2-5-7b-instruct_sp.zs.jsonl",

    "results_deepseek-r1-distill-8b.1s.jsonl",
    "results_deepseek-r1-distill-8b.zs.jsonl",
]
|
46 |
+
|
47 |
+
|
48 |
+
def revalidate_bracket(fp, reval_dir="revalidate_sudoku_all",
                       source_dirs=("revalidate_crosswords_all",)):
    """Re-run the ``GAME_NAME`` validator over one JSONL result file.

    Each input line is a JSON record.  Records whose ``game`` field starts
    with ``game_filename(GAME_NAME)`` are re-validated against a game that is
    rebuilt from the stored prompt on every ``turn == 1`` record, and their
    ``solved`` flag is overwritten with the fresh verdict.  Records of a
    session that come after a turn already judged solved are dropped.  All
    other records are copied through unchanged.

    Args:
        fp: File name of the result file (relative to the chosen directory).
        reval_dir: Sub-directory of ``MODEL_OUTPUT_DIR`` that receives the
            rewritten file; created when missing.
        source_dirs: Candidate sub-directories of ``MODEL_OUTPUT_DIR`` to read
            from, in priority order.  When none of them contains ``fp`` the
            file is read from ``MODEL_OUTPUT_DIR`` itself.

    Returns:
        ``(count_pos, count_neg)``: number of turns flipped to solved and
        number of turns flipped to unsolved by the re-validation.

    Relies on the module globals ``sid_prompt_dicts`` and ``game_cls`` being
    set before the call.
    """
    os.makedirs(MODEL_OUTPUT_DIR / reval_dir, exist_ok=True)

    # First candidate that actually holds the file, else "." (MODEL_OUTPUT_DIR).
    # The previous loop-and-break left the *last* candidate selected when no
    # candidate matched, so the "." fallback could never take effect.
    source_dir = next(
        (d for d in source_dirs if (MODEL_OUTPUT_DIR / d / fp).exists()),
        ".",
    )

    game_prefix = game_filename(GAME_NAME)  # loop-invariant
    count_pos, count_neg = 0, 0
    # Session state; initialised so a file that does not start on turn 1
    # trips the assertion below instead of raising NameError.
    cur_sid, cur_game, solved = None, None, False

    with (open(MODEL_OUTPUT_DIR / source_dir / fp, "r", encoding="utf8") as src,
          open(MODEL_OUTPUT_DIR / reval_dir / fp, "w", encoding="utf8") as dst,
          tqdm(total=3000, desc=fp) as pbar,
          ):
        for line in src:
            res = json.loads(line)
            if res["game"].startswith(game_prefix):
                if res["turn"] == 1:
                    # New session: rebuild the game from its stored prompt.
                    # The problem-set level is the suffix after the last "_".
                    level = res["game"].rsplit("_", 1)[-1]
                    cur_sid = res["session"]
                    cur_game = game_cls()
                    cur_game.load_game(sid_prompt_dicts[level][cur_sid])
                    pbar.update(1)
                elif solved:
                    # The session was already solved on an earlier turn;
                    # later turns are not written to the output.
                    continue
                else:
                    assert cur_sid == res["session"], (
                        f"turn {res['turn']} of session {res['session']!r} "
                        f"does not follow a turn-1 record of that session"
                    )

                solved, _ = cur_game.validate(res["response"])
                if solved and not res["solved"]:
                    count_pos += 1
                elif not solved and res["solved"]:
                    count_neg += 1
                res["solved"] = solved

            # NOTE(review): records of other games are assumed to be copied
            # through (the output feeds the next revalidation stage); the
            # original indentation was not recoverable -- confirm.
            dst.write(json.dumps(res))
            dst.write("\n")
    return count_pos, count_neg
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
    def load(k):
        """Read and return the parsed JSON problem set for level suffix `k`."""
        path = f"{PROBLEMSET_DIR}/{game_filename(GAME_NAME)}_{k}.json"
        with open(path, "r", encoding="utf8") as f:
            return json.load(f)

    # Both names below are module-level globals that revalidate_bracket() reads.
    sid_prompt_dicts = {}
    for level in map(str, range(1, 4)):
        sid_prompt_dicts[level] = load(level)
    game_cls = _game_class_from_name(GAME_NAME)

    # Re-validate each listed output file and print its (count_pos, count_neg).
    for fp in OUTPUT_FILENAMES:
        counts = revalidate_bracket(fp)
        print(counts)
|
94 |
+
|
textgames-scrabble-black2-ss.png
CHANGED
![]() |
![]() |
Git LFS Details
|
textgames/__init__.py
CHANGED
@@ -14,8 +14,10 @@ from pandas import read_csv
|
|
14 |
import json
|
15 |
|
16 |
|
17 |
-
# [
|
18 |
-
# "
|
|
|
|
|
19 |
THE_GAMES = {
|
20 |
k: v.get_game_name() for k, v in [
|
21 |
("1", CrosswordArrangerGame),
|
@@ -60,12 +62,13 @@ def _game_class_from_name(game_name):
|
|
60 |
return None
|
61 |
|
62 |
|
63 |
-
def preload_game(game_name, level_id, user):
|
64 |
game_cls = _game_class_from_name(game_name)
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
69 |
print(f"preload_game('{game_name}', '{level_id}', '{user['email']}') on {sid}")
|
70 |
|
71 |
with open(f"problemsets/{game_filename(game_name)}_{level_id}.json", "r", encoding="utf8") as f:
|
|
|
14 |
import json
|
15 |
|
16 |
|
17 |
+
# [
|
18 |
+
# "📰\tCrossword Arranger", "🧩\tText Sudoku", "🏝️\tIslands", "🔑\tPassword Game",
|
19 |
+
# "📈\tOrdering Text", "🔤\tAnagram Scribble", "🗳️\tBracket Game", "🔎\tString Search",
|
20 |
+
# ]
|
21 |
THE_GAMES = {
|
22 |
k: v.get_game_name() for k, v in [
|
23 |
("1", CrosswordArrangerGame),
|
|
|
62 |
return None
|
63 |
|
64 |
|
65 |
+
def preload_game(game_name, level_id, user, sid=None):
|
66 |
game_cls = _game_class_from_name(game_name)
|
67 |
+
if not sid:
|
68 |
+
email_sid_dict = read_csv(
|
69 |
+
f"{os.getenv('TEXTGAMES_OUTPUT_DIR')}/textgames_userauth.tsv", sep='\t'
|
70 |
+
).dropna().set_index("EMAIL").SID.to_dict()
|
71 |
+
sid = email_sid_dict.get(user["email"])
|
72 |
print(f"preload_game('{game_name}', '{level_id}', '{user['email']}') on {sid}")
|
73 |
|
74 |
with open(f"problemsets/{game_filename(game_name)}_{level_id}.json", "r", encoding="utf8") as f:
|
textgames/anagram_scribble/anagram_scribble.py
CHANGED
@@ -5,6 +5,7 @@ import json
|
|
5 |
import string
|
6 |
import re
|
7 |
|
|
|
8 |
class AnagramScribble(BaseGame):
|
9 |
@staticmethod
|
10 |
def get_game_name() -> str:
|
@@ -43,6 +44,18 @@ class AnagramScribble(BaseGame):
|
|
43 |
if total_chars_extraction != "Error loading game state.":
|
44 |
characters = total_chars_extraction.split(",")
|
45 |
self.total_chars = [char.strip().strip("'") for char in characters]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
48 |
self.low_num_chars = kwargs['low_num_chars']
|
@@ -57,16 +70,16 @@ class AnagramScribble(BaseGame):
|
|
57 |
|
58 |
def _get_prompt(self) -> str:
|
59 |
if self.allow_repeat:
|
60 |
-
prompt = f"Construct a valid {self.num_chars}-character English word from the following letters:\n{self.total_chars}.\nEach character can be used multiple times. Please write None if there is no valid combination."
|
61 |
else:
|
62 |
-
prompt = f"Construct a valid {self.num_chars}-character English word from the following letters:\n{self.total_chars}.\nEach character can only be used once. Please write None if there is no valid combination."
|
63 |
return prompt
|
64 |
|
65 |
def _validate(self, answer: str) -> (bool, str):
|
66 |
-
answer
|
67 |
-
if self.possible_ans != "" and answer == "none":
|
68 |
val_msg = "There is a valid answer."
|
69 |
return False, val_msg
|
|
|
70 |
if len(answer) != self.num_chars:
|
71 |
val_msg = f"Your answer must be exactly {self.num_chars} characters long"
|
72 |
return False, val_msg
|
@@ -74,12 +87,31 @@ class AnagramScribble(BaseGame):
|
|
74 |
if char not in self.total_chars:
|
75 |
val_msg = "Your answer must only contain the characters provided"
|
76 |
return False, val_msg
|
77 |
-
if (not self.allow_repeat and (len(set(answer)) != len(answer))
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
if answer not in self.WORD_LIST_BIN[str(self.num_chars)]:
|
82 |
val_msg = "Your answer is not a valid English word"
|
83 |
return False, val_msg
|
84 |
|
85 |
return True, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import string
|
6 |
import re
|
7 |
|
8 |
+
|
9 |
class AnagramScribble(BaseGame):
|
10 |
@staticmethod
|
11 |
def get_game_name() -> str:
|
|
|
44 |
if total_chars_extraction != "Error loading game state.":
|
45 |
characters = total_chars_extraction.split(",")
|
46 |
self.total_chars = [char.strip().strip("'") for char in characters]
|
47 |
+
self.possible_ans = ""
|
48 |
+
_chars = sorted(self.total_chars)
|
49 |
+
for w in self.WORD_LIST_BIN[str(self.num_chars)]:
|
50 |
+
_ans = sorted(w)
|
51 |
+
j, k = 0, 0
|
52 |
+
while j < len(_ans) and k < len(_chars):
|
53 |
+
if _ans[j] == _chars[k]:
|
54 |
+
j += 1
|
55 |
+
k += 1
|
56 |
+
if j >= len(_ans):
|
57 |
+
self.possible_ans = w
|
58 |
+
break
|
59 |
|
60 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
61 |
self.low_num_chars = kwargs['low_num_chars']
|
|
|
70 |
|
71 |
def _get_prompt(self) -> str:
|
72 |
if self.allow_repeat:
|
73 |
+
prompt = f"Construct a valid {self.num_chars}-character English word from the following letters:\n{self.total_chars}.\nEach character can be used multiple times. Please write None if there is no valid combination. Print only the answer.\n"
|
74 |
else:
|
75 |
+
prompt = f"Construct a valid {self.num_chars}-character English word from the following letters:\n{self.total_chars}.\nEach character can only be used once. Please write None if there is no valid combination. Print only the answer.\n"
|
76 |
return prompt
|
77 |
|
78 |
def _validate(self, answer: str) -> (bool, str):
|
79 |
+
if self.possible_ans != "" and answer == "None":
|
|
|
80 |
val_msg = "There is a valid answer."
|
81 |
return False, val_msg
|
82 |
+
answer = answer.lower()
|
83 |
if len(answer) != self.num_chars:
|
84 |
val_msg = f"Your answer must be exactly {self.num_chars} characters long"
|
85 |
return False, val_msg
|
|
|
87 |
if char not in self.total_chars:
|
88 |
val_msg = "Your answer must only contain the characters provided"
|
89 |
return False, val_msg
|
90 |
+
# if (not self.allow_repeat and (len(set(answer)) != len(answer))
|
91 |
+
# and (len(self.possible_ans) == len(set(self.possible_ans)))):
|
92 |
+
if not self.allow_repeat:
|
93 |
+
_ans = sorted(answer)
|
94 |
+
_chars = sorted(self.total_chars)
|
95 |
+
j, k = 0, 0
|
96 |
+
while j < len(_ans) and k < len(_chars):
|
97 |
+
if _ans[j] == _chars[k]:
|
98 |
+
j += 1
|
99 |
+
k += 1
|
100 |
+
if j < len(_ans):
|
101 |
+
val_msg = "Your answer must not contain repeated characters"
|
102 |
+
return False, val_msg
|
103 |
if answer not in self.WORD_LIST_BIN[str(self.num_chars)]:
|
104 |
val_msg = "Your answer is not a valid English word"
|
105 |
return False, val_msg
|
106 |
|
107 |
return True, ""
|
108 |
+
|
109 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed illustrative prompt together with an answer for it."""
    prompt_lines = [
        "Construct a valid 5-character English word from the following letters:",
        "['e', 'l', 'o', 'b', 's', 'p'].",
        "Each character can be used multiple times. Please write None if there is no valid combination."
        " Print only the answer.",
        "",  # the prompt ends with a newline
    ]
    return "\n".join(prompt_lines), "sleep"
|
117 |
+
|
textgames/bracket_game/bracket_game.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import random
|
2 |
import re
|
|
|
3 |
from pathlib import Path
|
4 |
from textgames.base_game import BaseGame
|
5 |
#%%
|
@@ -57,48 +58,94 @@ class BracketGame(BaseGame):
|
|
57 |
self.MULTI_WORD_LIST.append(self.WORD_LIST[num1] + self.WORD_LIST[num2])
|
58 |
|
59 |
def _validate(self, answer: str) -> (bool, str):
|
60 |
-
|
61 |
-
arr = answer.split(rule[0])
|
62 |
-
|
63 |
-
if rule[1][1] not in arr[0] or rule[1][2] not in arr[1]:
|
64 |
-
val_msg = f"{rule[0]} is not between the correct bracket, {rule[1][1]} not in {arr[0]} and {rule[1][2]} not in {arr[1]}"
|
65 |
-
return False, val_msg
|
66 |
-
|
67 |
-
filter_answer = answer
|
68 |
-
for i in range(0, 26):
|
69 |
-
cc = chr(ord("a") + i)
|
70 |
-
filter_answer = filter_answer.replace(cc,"")
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
}
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
93 |
val_msg = "There is a closing bracket without an open bracket"
|
94 |
return False, val_msg
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
val_msg = f"The depth of the bracket is {
|
100 |
return False, val_msg
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
103 |
num_words = kwargs["num_words"]
|
104 |
num_rules = kwargs["num_rules"]
|
@@ -141,6 +188,7 @@ class BracketGame(BaseGame):
|
|
141 |
prompt = f"You are given a text {self.string} Your job is to put some valid parenthesis brackets in the text such that:\n"
|
142 |
for rule in self.rules:
|
143 |
prompt += f"- \"{rule[0]}\" is inside a {rule[1][0]} bracket\n"
|
|
|
144 |
prompt += f"The bracket depth must be {self.depth} and print only the answer\n"
|
145 |
return prompt
|
146 |
|
@@ -159,7 +207,7 @@ class BracketGame(BaseGame):
|
|
159 |
else:
|
160 |
return 0
|
161 |
|
162 |
-
content = state_string.split("the text such that:")[1].split("\nThe
|
163 |
|
164 |
self.words = []
|
165 |
self.rules = []
|
@@ -188,3 +236,14 @@ class BracketGame(BaseGame):
|
|
188 |
self.create_multiple_words()
|
189 |
|
190 |
sort_game_states(self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import random
|
2 |
import re
|
3 |
+
from bisect import bisect_left
|
4 |
from pathlib import Path
|
5 |
from textgames.base_game import BaseGame
|
6 |
#%%
|
|
|
58 |
self.MULTI_WORD_LIST.append(self.WORD_LIST[num1] + self.WORD_LIST[num2])
|
59 |
|
60 |
def _validate(self, answer: str) -> (bool, str):
|
61 |
+
answer = "".join(answer.split()).lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
if ("".join(filter(lambda a: a.isalpha(), answer)) !=
|
64 |
+
"".join(filter(lambda a: a.isalpha(), self.string.lower()))):
|
65 |
+
val_msg = f"You are not allowed to change the character sequence of base text '{self.string}'."
|
66 |
+
return False, val_msg
|
67 |
+
|
68 |
+
char2type_op = {b[1]: b[0] for b in self.BRACKETS}
|
69 |
+
char2type_ed = {b[2]: b[0] for b in self.BRACKETS}
|
70 |
+
|
71 |
+
depth_count = {b[0]: [(-1, 0)] for b in self.BRACKETS}
|
72 |
+
|
73 |
+
def push(dc, v):
|
74 |
+
cur_depth = dc[-1][-1]
|
75 |
+
if cur_depth < 0:
|
76 |
+
return False
|
77 |
+
dc.append((i, cur_depth + v))
|
78 |
+
return True
|
79 |
+
|
80 |
+
mak, cur_mak = 0, 0
|
81 |
+
for i, c in enumerate(answer):
|
82 |
+
if c in char2type_op:
|
83 |
+
push(depth_count[char2type_op[c]], 1)
|
84 |
+
cur_mak += 1
|
85 |
+
elif c in char2type_ed:
|
86 |
+
if not push(depth_count[char2type_ed[c]], -1):
|
87 |
val_msg = "There is a closing bracket without an open bracket"
|
88 |
return False, val_msg
|
89 |
+
cur_mak -= 1
|
90 |
+
mak = max(mak, cur_mak)
|
91 |
+
|
92 |
+
if mak != self.depth:
|
93 |
+
val_msg = f"The depth of the bracket is {mak}. The expected depth is {self.depth}"
|
94 |
return False, val_msg
|
95 |
|
96 |
+
for rule in self.rules:
|
97 |
+
i = answer.find(rule[0])
|
98 |
+
if i < 0:
|
99 |
+
val_msg = f"The text '{rule[0]}' is not found in your answer."
|
100 |
+
return False, val_msg
|
101 |
+
|
102 |
+
i_depth = bisect_left(depth_count[rule[1][0]], (i, -1)) - 1
|
103 |
+
if depth_count[rule[1][0]][i_depth][-1] < 1:
|
104 |
+
val_msg = f"The text '{rule[0]}' is not inside any {rule[1][0]} bracket {rule[1][1]} {rule[1][2]}"
|
105 |
+
return False, val_msg
|
106 |
+
|
107 |
+
# arr = answer.split(rule[0])
|
108 |
+
# if rule[1][1] not in arr[0] or rule[1][2] not in arr[1]:
|
109 |
+
# val_msg = f"The text '{rule[0]}' is not between the correct bracket, {rule[1][1]} not in {arr[0]} and {rule[1][2]} not in {arr[1]}"
|
110 |
+
# return False, val_msg
|
111 |
+
|
112 |
+
return True, ""
|
113 |
+
|
114 |
+
# filter_answer = answer
|
115 |
+
# for i in range(0, 26):
|
116 |
+
# cc = chr(ord("a") + i)
|
117 |
+
# filter_answer = filter_answer.replace(cc,"")
|
118 |
+
#
|
119 |
+
# cc = chr(ord("A") + i)
|
120 |
+
# filter_answer = filter_answer.replace(cc,"")
|
121 |
+
#
|
122 |
+
# open_bracket_list = ["[", "{", "(", "<"]
|
123 |
+
# close_bracket_map = {
|
124 |
+
# "[":"]", "{":"}", "(":")", "<":">"
|
125 |
+
# }
|
126 |
+
#
|
127 |
+
# # check max depth
|
128 |
+
# count = 0
|
129 |
+
# st = []
|
130 |
+
#
|
131 |
+
# for i in range(len(filter_answer)):
|
132 |
+
# if (filter_answer[i] in open_bracket_list):
|
133 |
+
# st.append(filter_answer[i]) # pushing the bracket in the stack
|
134 |
+
# else:
|
135 |
+
# if len(st) > 0 and (filter_answer[i] == close_bracket_map[st[-1]]):
|
136 |
+
# if (count < len(st)):
|
137 |
+
# count = len(st)
|
138 |
+
# st.pop()
|
139 |
+
# else:
|
140 |
+
# val_msg = "There is a closing bracket without an open bracket"
|
141 |
+
# return False, val_msg
|
142 |
+
#
|
143 |
+
# if count == self.depth:
|
144 |
+
# return True, ""
|
145 |
+
# else:
|
146 |
+
# val_msg = f"The depth of the bracket is {count}. The expected depth is {self.depth}"
|
147 |
+
# return False, val_msg
|
148 |
+
|
149 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
150 |
num_words = kwargs["num_words"]
|
151 |
num_rules = kwargs["num_rules"]
|
|
|
188 |
prompt = f"You are given a text {self.string} Your job is to put some valid parenthesis brackets in the text such that:\n"
|
189 |
for rule in self.rules:
|
190 |
prompt += f"- \"{rule[0]}\" is inside a {rule[1][0]} bracket\n"
|
191 |
+
prompt += "The open and close parenthesis for block is [ ], curly is { }, round is ( ), and angle is < >\n"
|
192 |
prompt += f"The bracket depth must be {self.depth} and print only the answer\n"
|
193 |
return prompt
|
194 |
|
|
|
207 |
else:
|
208 |
return 0
|
209 |
|
210 |
+
content = state_string.split("the text such that:")[1].split("\nThe open and close parenthesis ")[0].split("\n")
|
211 |
|
212 |
self.words = []
|
213 |
self.rules = []
|
|
|
236 |
self.create_multiple_words()
|
237 |
|
238 |
sort_game_states(self)
|
239 |
+
|
240 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed illustrative prompt together with an answer for it."""
    prompt_lines = [
        "You are given a text fabuloustextgames Your job is to put some valid parenthesis brackets in the text such that:",
        "- \"games\" is inside a round bracket",
        "- \"text\" is inside a angle bracket",
        "- \"fabulous\" is inside a block bracket",
        "The open and close parenthesis for block is [ ], curly is { }, round is ( ), and angle is < >",
        "The bracket depth must be 2 and print only the answer",
        "",  # the prompt ends with a newline
    ]
    return "\n".join(prompt_lines), "[[fabulous]<text>(games)]"
|
textgames/crossword_arranger/crossword_arranger.py
CHANGED
@@ -125,19 +125,47 @@ class CrosswordArrangerGame(BaseGame):
|
|
125 |
return prompt
|
126 |
|
127 |
def _validate(self, answer: str) -> (bool, str):
|
128 |
-
|
|
|
|
|
129 |
val_msg = ""
|
|
|
|
|
|
|
|
|
130 |
if len(ans_hor) != self.board_size:
|
131 |
val_msg = f"Mismatch answer length found!! Expected size of {self.board_size}, got {len(ans_hor)}."
|
132 |
return False, val_msg
|
|
|
|
|
|
|
|
|
133 |
ans_ver = [''.join(ans_hor[r][c] for r in range(self.board_size)) for c in range(self.board_size)]
|
134 |
word_set = set(self.word_list)
|
135 |
-
for w in chain(ans_hor, ans_ver):
|
136 |
if w not in word_set:
|
|
|
|
|
137 |
return False, val_msg
|
138 |
word_set.remove(w)
|
139 |
return True, val_msg
|
140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
#%%
|
143 |
|
|
|
125 |
return prompt
|
126 |
|
127 |
def _validate(self, answer: str) -> (bool, str):
|
128 |
+
answer = answer if answer else ""
|
129 |
+
# ans_hor = list(filter(None, answer.lower().replace(' ', '\n').split("\n")))
|
130 |
+
ans_hor = answer.lower().split()
|
131 |
val_msg = ""
|
132 |
+
if len(ans_hor) != self.board_size:
|
133 |
+
arr = answer.lower().split()
|
134 |
+
if all(len(l) == 1 for l in arr) and (len(arr) == self.board_size * self.board_size):
|
135 |
+
ans_hor = ["".join(arr[i:i+self.board_size]) for i in range(0, len(arr), self.board_size)]
|
136 |
if len(ans_hor) != self.board_size:
|
137 |
val_msg = f"Mismatch answer length found!! Expected size of {self.board_size}, got {len(ans_hor)}."
|
138 |
return False, val_msg
|
139 |
+
for w in ans_hor:
|
140 |
+
if len(w) != self.board_size:
|
141 |
+
val_msg = f"Mismatch answer length found!! Expected size of {self.board_size}, got {len(w)}."
|
142 |
+
return False, val_msg
|
143 |
ans_ver = [''.join(ans_hor[r][c] for r in range(self.board_size)) for c in range(self.board_size)]
|
144 |
word_set = set(self.word_list)
|
145 |
+
for i, w in enumerate(chain(ans_hor, ans_ver)):
|
146 |
if w not in word_set:
|
147 |
+
val_msg = (f"Mismatch answer word found!! {'Horizontal' if i < self.board_size else 'Vertical'} word"
|
148 |
+
f" '{w}' is not in the word set.")
|
149 |
return False, val_msg
|
150 |
word_set.remove(w)
|
151 |
return True, val_msg
|
152 |
|
153 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed illustrative prompt together with an answer for it."""
    words = ("app", "all", "and", "lee", "let", "pat", "pee", "pet")
    header = ("Given a board size of 3x3, arrange a possible crossword puzzle answer from a list of words.\n"
              "Item in the list can only be used once.\n\n"
              "List of words:\n")
    bullet_list = "".join(f"- {word}\n" for word in words)
    prompt = header + bullet_list + "\nPrint only the answer."
    answer = "\n".join(("app", "lee", "let"))
    return prompt, answer
|
169 |
|
170 |
#%%
|
171 |
|
textgames/islands/islands.py
CHANGED
@@ -99,8 +99,8 @@ class Islands(BaseGame):
|
|
99 |
answer = [a.replace(" ", "").lower().strip() for a in answer]
|
100 |
|
101 |
# check the size
|
102 |
-
if len(answer) != self.N or len(
|
103 |
-
val_msg = f"2D grid is not {self.N} x {self.N}. ({len(answer)} x {len(answer
|
104 |
return False, val_msg
|
105 |
|
106 |
# check the tiles, ensure they are valid
|
@@ -194,4 +194,16 @@ Your 2D grid must follow the following rules:
|
|
194 |
|
195 |
Print only the answer.
|
196 |
"""
|
197 |
-
return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
answer = [a.replace(" ", "").lower().strip() for a in answer]
|
100 |
|
101 |
# check the size
|
102 |
+
if len(answer) != self.N or any((len(a) < self.N) for a in answer):
|
103 |
+
val_msg = f"2D grid is not {self.N} x {self.N}. ({len(answer)} x {set(len(a) for a in answer)})"
|
104 |
return False, val_msg
|
105 |
|
106 |
# check the tiles, ensure they are valid
|
|
|
194 |
|
195 |
Print only the answer.
|
196 |
"""
|
197 |
+
return prompt
|
198 |
+
|
199 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed illustrative prompt together with an answer for it."""
    prompt_lines = [
        # The first two lines deliberately keep their trailing space.
        "You are asked to construct a 2D 5 x 5 grid, consisting of water tiles (denoted by \u2019.\u2019), ",
        "land tiles (denoted by \u2019#\u2019). ",
        "",
        "A group of connected land tiles in 4 cardinal directions forms an island.",
        "",
        "Your 2D grid must follow the following rules:",
        "- There must be exactly 1 islands.",
        "- The size of each island must be from 1 to 2 tiles.",
        "",
        "Print only the answer.",
        "",  # the prompt ends with a newline
    ]
    grid_rows = ["...##"] + ["....."] * 4
    return "\n".join(prompt_lines), "\n".join(grid_rows)
|
textgames/ordering_text/ordering_text.py
CHANGED
@@ -5,10 +5,10 @@ Rules Description
|
|
5 |
|
6 |
word length:
|
7 |
- example: word less than 5 characters gets 10 points
|
8 |
-
- possible operands: {
|
9 |
-
-
|
10 |
-
- possible combinations: {
|
11 |
-
- only 1
|
12 |
|
13 |
neighboring / consecutive chars
|
14 |
- example: every pair of consecutive consonant gets 5 points
|
@@ -66,6 +66,15 @@ from textgames.base_game import BaseGame
|
|
66 |
from textgames.assets.word_list import WORDS_LIST, WORDS_BY_LEN
|
67 |
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
#%%
|
70 |
class Scoring:
|
71 |
def __init__(self, point: int):
|
@@ -505,14 +514,15 @@ class OrderingTextGame(BaseGame):
|
|
505 |
return self.answer # sorted(self.words, key=lambda word: (self.get_point(word), word))
|
506 |
|
507 |
def _validate(self, answer: str) -> (bool, str):
|
508 |
-
answer = answer.lower().replace('
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
|
|
516 |
|
517 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
518 |
if "preset_config" in kwargs:
|
@@ -588,6 +598,26 @@ class OrderingTextGame(BaseGame):
|
|
588 |
prompt += "\nPrint only the answer."
|
589 |
return prompt
|
590 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
591 |
|
592 |
#%%
|
593 |
|
|
|
5 |
|
6 |
word length:
|
7 |
- example: word less than 5 characters gets 10 points
|
8 |
+
- possible operands: {\\eq, \\lt, \\gt, \\ne}
|
9 |
+
- \\le and \\ge will be randomized for prompt generation
|
10 |
+
- possible combinations: {\\gt\\lt, \\gt\\lt\\ne}
|
11 |
+
- only 1 \\ne is considered
|
12 |
|
13 |
neighboring / consecutive chars
|
14 |
- example: every pair of consecutive consonant gets 5 points
|
|
|
66 |
from textgames.assets.word_list import WORDS_LIST, WORDS_BY_LEN
|
67 |
|
68 |
|
69 |
+
#%%
|
70 |
+
# Maps a 1-based position (1..18) to its English ordinal word.
index_to_word = dict(enumerate(
    (
        "first", "second", "third", "fourth", "fifth",
        "sixth", "seventh", "eighth", "ninth", "tenth",
        "eleventh", "twelfth", "thirteenth", "fourteenth",
        "fifteenth", "sixteenth", "seventeenth", "eighteenth",
    ),
    start=1,
))
|
76 |
+
|
77 |
+
|
78 |
#%%
|
79 |
class Scoring:
|
80 |
def __init__(self, point: int):
|
|
|
514 |
return self.answer # sorted(self.words, key=lambda word: (self.get_point(word), word))
|
515 |
|
516 |
def _validate(self, answer: str) -> (bool, str):
|
517 |
+
answer = answer.lower().replace(',', ' ').split()
|
518 |
+
gold = self.get_answer()
|
519 |
+
if len(answer) < len(gold):
|
520 |
+
return False, f"Your answer is too short. There should be {len(gold)} items."
|
521 |
+
for i, (a, b) in enumerate(zip(answer, self.get_answer()), 1):
|
522 |
+
if a != b:
|
523 |
+
val_msg = f"'{a}' is not supposed to be the {index_to_word[i]} word in the order."
|
524 |
+
return False, val_msg
|
525 |
+
return True, ""
|
526 |
|
527 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
528 |
if "preset_config" in kwargs:
|
|
|
598 |
prompt += "\nPrint only the answer."
|
599 |
return prompt
|
600 |
|
601 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed illustrative prompt together with an answer for it."""
    prompt_lines = [
        "Given a set of rules to calculate point, sort the set of words in decreasing order.",
        "When there 2 or more words with same point, sort lexicographically.",
        "",
        "Rules:",
        "- add 10 points if there exists 'u' in the word",
        "",
        "Words:",
        "- hudi",
        "- genta",
        "- aji",
        "- ruochen",
        "",
        "Print only the answer.",
    ]
    expected_order = ["hudi", "ruochen", "aji", "genta"]
    return "\n".join(prompt_lines), "\n".join(expected_order)
|
620 |
+
|
621 |
|
622 |
#%%
|
623 |
|
textgames/password_game/password_game.py
CHANGED
@@ -274,3 +274,13 @@ class PasswordGame(BaseGame):
|
|
274 |
self.rules = [rule for rule in new_rules]
|
275 |
|
276 |
sort_game_states(self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
self.rules = [rule for rule in new_rules]
|
275 |
|
276 |
sort_game_states(self)
|
277 |
+
|
278 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed illustrative prompt together with an answer for it."""
    intro = ("Please write a text string without any space by following a set of given rules."
             " Please write only the answer and follow the following criteria:")
    criteria = [
        "- the text has 6 english character",
        "- the text has 0 uppercase characters",
    ]
    # Every line of the prompt, including the last, is newline-terminated.
    prompt = "".join(f"{line}\n" for line in [intro, *criteria])
    return prompt, "hoodie"
|
286 |
+
|
textgames/string_search/string_search.py
CHANGED
@@ -309,4 +309,16 @@ Find a substring of exactly {self.answer_len} characters long that:
|
|
309 |
{extra_constraints}
|
310 |
Print only the answer.
|
311 |
"""
|
312 |
-
return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
{extra_constraints}
|
310 |
Print only the answer.
|
311 |
"""
|
312 |
+
return prompt
|
313 |
+
|
314 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed illustrative prompt together with an answer for it."""
    prompt_lines = [
        "You are given the following string:",
        "hudigentaajiruochen",
        "",
        "Find a substring of exactly 3 characters long that:",
        " - Contains t",
        " - Does not contain i and a",
        "",
        "Print only the answer.",
        "",  # the prompt ends with a newline
    ]
    return "\n".join(prompt_lines), "ent"
|
324 |
+
|
textgames/sudoku/sudoku.py
CHANGED
@@ -9,6 +9,7 @@ Please solve the 9x9 sudoku with 1,2,3,4,5,6,7,8,9 as the values and fill _ with
|
|
9 |
Print only the answer.
|
10 |
"""
|
11 |
|
|
|
12 |
#%%
|
13 |
class Sudoku(BaseGame):
|
14 |
@staticmethod
|
@@ -28,34 +29,47 @@ class Sudoku(BaseGame):
|
|
28 |
for j in range(self.size):
|
29 |
num = mat[i][j]
|
30 |
if num == self.empty_character:
|
31 |
-
return False
|
32 |
|
33 |
subgrid_index = (i // self.srn) * self.srn + (j // self.srn)
|
34 |
|
35 |
-
if num in rows[i]
|
36 |
-
return False
|
37 |
-
|
|
|
|
|
|
|
|
|
38 |
rows[i].add(num)
|
39 |
cols[j].add(num)
|
40 |
subgrids[subgrid_index].add(num)
|
41 |
|
42 |
-
return True
|
43 |
|
44 |
def _validate(self, input) -> (bool, str):
|
45 |
mat = [[self.empty_character for i in range(self.size)] for j in range(self.size)]
|
46 |
|
|
|
47 |
arr = input.split()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
for i in range(len(arr)):
|
49 |
for j in range(len(arr[i])):
|
50 |
if arr[i][j] not in self.char_to_id:
|
51 |
-
val_msg = "
|
52 |
return False, val_msg
|
53 |
|
54 |
mat[i][j] = self.char_to_id[arr[i][j]]
|
55 |
if arr[i][j] != self.mat[i][j] and self.mat[i][j] != self.empty_character:
|
56 |
val_msg = "One or more characters are replaced"
|
57 |
return False, val_msg
|
58 |
-
|
|
|
59 |
|
60 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
61 |
size=kwargs["size"]
|
@@ -228,4 +242,14 @@ class Sudoku(BaseGame):
|
|
228 |
self.char_to_id = {}
|
229 |
for c_id in range(len(self.characters)):
|
230 |
self.char_to_id[self.characters[c_id]] = c_id
|
231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
Print only the answer.
|
10 |
"""
|
11 |
|
12 |
+
|
13 |
#%%
|
14 |
class Sudoku(BaseGame):
|
15 |
@staticmethod
|
|
|
29 |
for j in range(self.size):
|
30 |
num = mat[i][j]
|
31 |
if num == self.empty_character:
|
32 |
+
return False, "There are unfilled cells"
|
33 |
|
34 |
subgrid_index = (i // self.srn) * self.srn + (j // self.srn)
|
35 |
|
36 |
+
if num in rows[i]:
|
37 |
+
return False, f"Duplicated row value ({num}) for cell in row {i+1} column {j+1}."
|
38 |
+
elif num in cols[j]:
|
39 |
+
return False, f"Duplicated column value ({num}) for cell in row {i+1} column {j+1}."
|
40 |
+
elif num in subgrids[subgrid_index]:
|
41 |
+
return False, f"Duplicated subgrid value ({num}) for cell in row {i+1} column {j+1}."
|
42 |
+
|
43 |
rows[i].add(num)
|
44 |
cols[j].add(num)
|
45 |
subgrids[subgrid_index].add(num)
|
46 |
|
47 |
+
return True, ""
|
48 |
|
49 |
def _validate(self, input) -> (bool, str):
|
50 |
mat = [[self.empty_character for i in range(self.size)] for j in range(self.size)]
|
51 |
|
52 |
+
input = input if input else ""
|
53 |
arr = input.split()
|
54 |
+
if all(len(l) == 1 for l in arr) and (len(arr) == self.size * self.size):
|
55 |
+
arr = ["".join(arr[i:i+self.size]) for i in range(0, len(arr), self.size)]
|
56 |
+
if (len(arr) != self.size) or any(len(arr[i]) != self.size for i in range(len(arr))):
|
57 |
+
arr = input.split("\n")
|
58 |
+
val_msg = f"Your answer is wrong in shape, it should be {self.size}x{self.size} sudoku."
|
59 |
+
return False, val_msg
|
60 |
+
|
61 |
for i in range(len(arr)):
|
62 |
for j in range(len(arr[i])):
|
63 |
if arr[i][j] not in self.char_to_id:
|
64 |
+
val_msg = "There are unrecognized characters, or possibly unfilled cells."
|
65 |
return False, val_msg
|
66 |
|
67 |
mat[i][j] = self.char_to_id[arr[i][j]]
|
68 |
if arr[i][j] != self.mat[i][j] and self.mat[i][j] != self.empty_character:
|
69 |
val_msg = "One or more characters are replaced"
|
70 |
return False, val_msg
|
71 |
+
|
72 |
+
return self.is_valid_sudoku(mat)
|
73 |
|
74 |
def _generate_new_game(self, *args, **kwargs) -> None:
|
75 |
size=kwargs["size"]
|
|
|
242 |
self.char_to_id = {}
|
243 |
for c_id in range(len(self.characters)):
|
244 |
self.char_to_id[self.characters[c_id]] = c_id
|
245 |
+
|
246 |
+
@staticmethod
def example() -> (str, str):
    """Return a fixed illustrative prompt together with an answer for it."""
    instruction = ("Please solve the 4x4 sudoku with A,B,C,D as the values and fill _ with the possible value and"
                   " only print the answer. Follow the sudoku rule.")
    puzzle_rows = ["A_CD", "CD_B", "_AD_", "DCBA"]
    solved_rows = ["ABCD", "CDAB", "BADC", "DCBA"]
    # Puzzle rows are space-separated in the prompt; answer rows are newline-separated.
    prompt = instruction + "\n" + " ".join(puzzle_rows)
    return prompt, "\n".join(solved_rows)
|
255 |
+
|