Upload folder using huggingface_hub

Files changed:
- .gitattributes (+0, -35)
- .gitignore (+1, -0)
- README.md (+34, -21)
- Try Gemma-2-9B.ipynb (+1, -218)
- agents/check_param.py (+9, -0)
- oauth_environ_google.sh (+1, -1)
- play_helper.py (+28, -17)
- play_with_hf.py (+89, -30)
- textgames_check_model_outputs.py (+172, -0)

.gitattributes
CHANGED
@@ -1,36 +1 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
 textgames-scrabble-black2-ss.png filter=lfs diff=lfs merge=lfs -text

.gitignore
CHANGED
@@ -4,6 +4,7 @@
 agents/*.sh
 user_outputs/
 model_outputs/__runs__
+runner_out/
 
 .idea/
 

README.md
CHANGED
@@ -9,24 +9,37 @@ hf_oauth: true
 ---
 # TextGames
 
-##
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-```
+## Play on Hosted Servers
+
+- HuggingFace Space
+
+  https://huggingface.co/spaces/fhudi/textgames
+  (login required)
+
+## Play on localhost
+
+- Setup
+  ```
+  ❱❱❱ pip install -r requirements.txt
+  ```
+
+- Play (Terminal)
+  ```
+  ❱❱❱ python play.py
+  ```
+
+- Play (Web UI)
+  ```
+  ❱❱❱ pip install gradio
+  ❱❱❱ GRADIO_SERVER_PORT=1080 python play_gradio.py
+  ```
+  Open `localhost:1080` to access.
+
+---
+
+## Extras
+
+- Optional Environment Variables
+  ```
+  TEXTGAMES_SHOW_HIDDEN_LEVEL=1
+  ```
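
The `TEXTGAMES_SHOW_HIDDEN_LEVEL` variable documented above is an opt-in flag read from the environment. As an illustration only (the repository's actual parsing may differ), such a flag is typically interpreted once at startup:

```python
# Illustrative sketch, not the repository's exact code: interpret the opt-in
# TEXTGAMES_SHOW_HIDDEN_LEVEL flag documented in the README above.
import os

def show_hidden_levels() -> bool:
    # Treat "1", "true", and "yes" (any casing) as enabled; anything else as off.
    return os.getenv("TEXTGAMES_SHOW_HIDDEN_LEVEL", "0").strip().lower() in ("1", "true", "yes")

if __name__ == "__main__":
    print("hidden levels enabled:", show_hidden_levels())
```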

Try Gemma-2-9B.ipynb
CHANGED
@@ -1,218 +1 @@
-{
-"nbformat": 4,
-"nbformat_minor": 0,
-"metadata": {
-"colab": {
-"private_outputs": true,
-"provenance": [],
-"authorship_tag": "ABX9TyPmvDoFpmwAf1QFBJZy7XSQ"
-},
-"kernelspec": {
-"name": "python3",
-"display_name": "Python 3"
-},
-"language_info": {
-"name": "python"
-}
-},
-"cells": [
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "Rli_enT6lBDT"
-},
-"outputs": [],
-"source": [
-"##%%\n",
-"import os\n",
-"import torch\n",
-"import random\n",
-"import numpy as np\n",
-"import argparse\n",
-"import json\n",
-"import cohere\n",
-"from openai import OpenAI\n"
-]
-},
-{
-"cell_type": "code",
-"source": [
-"##%%\n",
-"from tqdm import tqdm\n",
-"\n",
-"from collections import Counter\n",
-"\n",
-"from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM\n",
-"import hashlib\n",
-"\n",
-"from textgames import GAME_NAMES, GAME_IDS, LEVELS, LEVELS_HIDDEN, LEVEL_IDS, new_game\n"
-],
-"metadata": {
-"id": "dp1F32B8oSfD"
-},
-"execution_count": null,
-"outputs": []
-},
-{
-"cell_type": "code",
-"source": [
-"##%%\n",
-"gen_model_checkpoint = \"google/gemma-2-9b-it\"\n",
-"quantize = True"
-],
-"metadata": {
-"id": "jZF8bkUcojTX"
-},
-"execution_count": null,
-"outputs": []
-},
-{
-"cell_type": "code",
-"source": [
-"kwargs = {\n",
-" \"device_map\": \"auto\",\n",
-"} if quantize else {}"
-],
-"metadata": {
-"id": "VAF5sR9arYzS"
-},
-"execution_count": null,
-"outputs": []
-},
-{
-"cell_type": "code",
-"source": [
-"##%%\n",
-"gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **kwargs)\n",
-"tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, **kwargs)"
-],
-"metadata": {
-"id": "tzqldl8ooRVL"
-},
-"execution_count": null,
-"outputs": []
-},
-{
-"cell_type": "code",
-"source": [
-"gen_model.device"
-],
-"metadata": {
-"id": "FeBUXdkWsWrL"
-},
-"execution_count": null,
-"outputs": []
-},
-{
-"cell_type": "code",
-"source": [
-"def get_gemma_response(text):\n",
-" # global gen_model, tokenizer\n",
-" messages = [\n",
-" {\"role\": \"user\", \"content\": text},\n",
-" ]\n",
-"\n",
-" input_ids = tokenizer.apply_chat_template(\n",
-" messages,\n",
-" add_generation_prompt=True,\n",
-" return_tensors=\"pt\"\n",
-" ).to(gen_model.device)\n",
-"\n",
-" terminators = [\n",
-" tokenizer.eos_token_id,\n",
-" tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n",
-" ]\n",
-"\n",
-" outputs = gen_model.generate(\n",
-" input_ids,\n",
-" max_new_tokens=100,\n",
-" eos_token_id=terminators,\n",
-" do_sample=True,\n",
-" temperature=0.2,\n",
-" top_p=1\n",
-" )\n",
-"\n",
-" response = outputs[0][input_ids.shape[-1]:]\n",
-" return tokenizer.decode(response, skip_special_tokens=True)"
-],
-"metadata": {
-"id": "R5D4K-P2sPaj"
-},
-"execution_count": null,
-"outputs": []
-},
-{
-"cell_type": "code",
-"source": [
-"text = \\\n",
-"\"\"\"\n",
-"Given a set of rules to calculate point, sort the set of words in decreasing order.\n",
-"When there 2 or more words with same point, sort lexicographically.\n",
-"\n",
-"Rules:\n",
-"- every pair of consecutive consonant gets 5 points\n",
-"- every pair of consecutive vowel gets 3 points\n",
-"- add 1 point if there exists exactly 1 'g' in the word\n",
-"- word less than 5 characters gets 10 points\n",
-"- word starts with gen gets 100 points\n",
-"- word ends with ta gets -1000 point\n",
-"\n",
-"Words:\n",
-"- genta\n",
-"- winata\n",
-"- hudi\n",
-"- alham\n",
-"- aji\n",
-"- ruochen\n",
-"\n",
-"Print only the answer.\n",
-"\"\"\"\n",
-"\n",
-"# Answer:\n",
-"# - aji 10\n",
-"# - hudi 10\n",
-"# - ruochen 5 3\n",
-"# - alham 5\n",
-"# - genta 5 1 100 -1000\n",
-"# - winata -1000"
-],
-"metadata": {
-"id": "T_tk4hTGsxsR"
-},
-"execution_count": null,
-"outputs": []
-},
-{
-"cell_type": "code",
-"source": [
-"print(get_gemma_response(text))"
-],
-"metadata": {
-"id": "05OI36v6vGoY"
-},
-"execution_count": null,
-"outputs": []
-},
-{
-"cell_type": "code",
-"source": [
-"print(get_gemma_response(text))"
-],
-"metadata": {
-"id": "riwXqTc-tmNr"
-},
-"execution_count": null,
-"outputs": []
-},
-{
-"cell_type": "code",
-"source": [],
-"metadata": {
-"id": "T72sUG4_vYUa"
-},
-"execution_count": null,
-"outputs": []
-}
-]
-}
+
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":1,"metadata":{"id":"Rli_enT6lBDT","executionInfo":{"status":"ok","timestamp":1737395007014,"user_tz":-540,"elapsed":5212,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"outputs":[],"source":["##%%\n","import os\n","import pickle\n","import json\n","# import random\n","# import torch\n","# import numpy as np\n","# import argparse\n","# import cohere\n","# from openai import OpenAI\n"]},{"cell_type":"code","source":["##%%\n","# import hashlib\n","from tqdm import tqdm\n","from itertools import product\n","# from collections import Counter\n","\n","# from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM\n","from transformers import AutoTokenizer, AutoModelForCausalLM\n","from textgames import GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name\n"],"metadata":{"id":"dp1F32B8oSfD","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395010583,"user_tz":-540,"elapsed":3547,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"e9adeb5f-70eb-4ca9-dcbb-428e4b28ab41"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stderr","text":["/home/is/frederikus-h/miniconda3/envs/textgame/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n"," from .autonotebook import tqdm as notebook_tqdm\n"]}]},{"cell_type":"code","source":["os.environ.setdefault(\"TEXTGAMES_OUTPUT_DIR\", \"user_outputs\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"2wEu1V1wvxn0","executionInfo":{"status":"ok","timestamp":1737395010664,"user_tz":-540,"elapsed":67,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"cdcad20f-e357-4009-9f4f-0d4495ebd894"},"execution_count":3,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'user_outputs'"]},"metadata":{},"execution_count":3}]},{"cell_type":"code","source":["##%%\n","gen_model_checkpoint = \"google/gemma-2-9b-it\"\n","quantize = True"],"metadata":{"id":"jZF8bkUcojTX","executionInfo":{"status":"ok","timestamp":1737395010678,"user_tz":-540,"elapsed":13,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":4,"outputs":[]},{"cell_type":"code","source":["kwargs = {\n"," \"device_map\": \"auto\",\n","} if quantize else {}"],"metadata":{"id":"VAF5sR9arYzS","executionInfo":{"status":"ok","timestamp":1737395010683,"user_tz":-540,"elapsed":2,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":5,"outputs":[]},{"cell_type":"code","source":["##%%\n","gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, **kwargs)\n","tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, **kwargs)"],"metadata":{"id":"tzqldl8ooRVL","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395038547,"user_tz":-540,"elapsed":27859,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"902b638c-e6ce-4f8a-bba2-e9f7241c9a27"},"execution_count":6,"outputs":[{"output_type":"stream","name":"stderr","text":["Loading checkpoint shards: 
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:24<00:00, 6.19s/it]\n"]}]},{"cell_type":"code","source":["gen_model.device"],"metadata":{"id":"FeBUXdkWsWrL","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737395038552,"user_tz":-540,"elapsed":3,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"6437d1b7-02f8-47f5-d519-e979cefde795"},"execution_count":7,"outputs":[{"output_type":"execute_result","data":{"text/plain":["device(type='cuda', index=0)"]},"metadata":{},"execution_count":7}]},{"cell_type":"code","source":["def get_gemma_response(text):\n"," # global gen_model, tokenizer\n"," messages = [\n"," {\"role\": \"user\", \"content\": text},\n"," ]\n","\n"," input_ids = tokenizer.apply_chat_template(\n"," messages,\n"," add_generation_prompt=True,\n"," return_tensors=\"pt\"\n"," ).to(gen_model.device)\n","\n"," terminators = [\n"," tokenizer.eos_token_id,\n"," tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n"," ]\n","\n"," outputs = gen_model.generate(\n"," input_ids,\n"," max_new_tokens=100,\n"," eos_token_id=terminators,\n"," do_sample=True,\n"," temperature=.001,\n"," top_p=1,\n"," )\n","\n"," response = outputs[0][input_ids.shape[-1]:]\n"," return tokenizer.decode(response, skip_special_tokens=True)"],"metadata":{"id":"R5D4K-P2sPaj","executionInfo":{"status":"ok","timestamp":1737395038554,"user_tz":-540,"elapsed":1,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}}},"execution_count":8,"outputs":[]},{"cell_type":"markdown","source":["---\n","Example Call"],"metadata":{"id":"s5FEwOOvxf4h"}},{"cell_type":"code","source":["# @title\n","text = \\\n","\"\"\"\n","Given a set of rules to calculate point, sort the set of words in decreasing order.\n","When there 2 or more words with same point, sort lexicographically.\n","\n","Rules:\n","- every pair of consecutive consonant gets 5 points\n","- every pair of consecutive vowel gets 3 points\n","- add 1 point if there exists exactly 1 'g' in the word\n","- word less than 5 characters gets 10 points\n","- word starts with gen gets 100 points\n","- word ends with ta gets -1000 point\n","\n","Words:\n","- genta\n","- winata\n","- hudi\n","- alham\n","- aji\n","- ruochen\n","\n","Print only the answer.\n","\"\"\"\n","\n","print(text)"],"metadata":{"id":"T_tk4hTGsxsR","colab":{"base_uri":"https://localhost:8080/"},"cellView":"form","executionInfo":{"status":"ok","timestamp":1737392776367,"user_tz":-540,"elapsed":27,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"d5ea884f-d0fa-4134-ecd9-690eab51c976"},"execution_count":14,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","Given a set of rules to calculate point, sort the set of words in decreasing order.\n","When there 2 or more words with same point, sort lexicographically.\n","\n","Rules:\n","- every pair of consecutive consonant gets 5 points\n","- every pair of consecutive vowel gets 3 points\n","- add 1 point if there exists exactly 1 'g' in the word\n","- word less than 5 characters gets 10 points\n","- word starts with gen gets 100 points\n","- word ends with ta gets -1000 point\n","\n","Words:\n","- genta\n","- winata\n","- hudi\n","- alham\n","- aji\n","- ruochen\n","\n","Print only the answer.\n","\n"]}]},{"cell_type":"code","source":["# Gold Answer:\n","# - aji 10\n","# - hudi 10\n","# - ruochen 5 3\n","# - alham 5\n","# - genta 
5 1 100 -1000\n","# - winata -1000"],"metadata":{"id":"G-5yS4S-rdsN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(get_gemma_response(text))"],"metadata":{"id":"05OI36v6vGoY","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1737392724119,"user_tz":-540,"elapsed":6741,"user":{"displayName":"Frederikus Hudi","userId":"06160664103998835801"}},"outputId":"fe5d6ed2-d063-4f1c-b2e1-b3af8dbc456e"},"execution_count":9,"outputs":[{"output_type":"stream","name":"stdout","text":["genta\n","winata\n","ruochen\n","hudi\n","alham\n","aji \n","\n"]}]},{"cell_type":"markdown","source":["---\n","Automate run all sessions"],"metadata":{"id":"cxJ4WqHpxi75"}},{"cell_type":"code","source":["for game_name, difficulty_level in product([GAME_NAMES[4], *GAME_NAMES[:4], *GAME_NAMES[5:]], LEVEL_IDS[:3]):\n"," game_cls = _game_class_from_name(game_name)\n"," with open(f\"problemsets/{game_filename(game_name)}_{difficulty_level}.json\", \"r\", encoding=\"utf8\") as f:\n"," sid_prompt_dict = json.load(f)\n","\n"," correct_cnt = 0\n"," for sid, prompt in tqdm(list(sid_prompt_dict.items()), desc=f\"{game_filename(game_name)}_-_{difficulty_level}\"):\n"," cur_game = game_cls()\n"," cur_game.load_game(prompt)\n"," response = get_gemma_response(cur_game.get_prompt()).strip()\n"," solved, val_msg = cur_game.validate(response)\n"," with open(f\"model_outputs/results_gemma_2_9B_it.pkl\", \"ab\") as o:\n"," pickle.dump((f\"{game_filename(game_name)}_{difficulty_level}\", sid, response, (solved, val_msg)), o)\n"," if solved:\n"," correct_cnt += 1\n","\n"," print(f\"{game_name}_-_{difficulty_level}\")\n"," print(f\" Acc.: {correct_cnt / len(sid_prompt_dict):.2%}\")"],"metadata":{"id":"hCTXYpXa1UQ6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"GC-zkVI52IJX"},"execution_count":null,"outputs":[]}]}
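
The automated loop appended at the end of the new notebook pickles one record per session into `model_outputs/results_gemma_2_9B_it.pkl`. A minimal sketch (assuming the `(game_tag, sid, response, (solved, val_msg))` record layout used in that cell) for re-reading the stream and tallying per-game accuracy:

```python
# Sketch: tally accuracy from the pickle stream written by the notebook loop.
# Assumes records of the form (game_tag, sid, response, (solved, val_msg)).
import pickle
from collections import Counter

solved, total = Counter(), Counter()
with open("model_outputs/results_gemma_2_9B_it.pkl", "rb") as f:
    while True:
        try:
            game_tag, sid, response, (ok, _msg) = pickle.load(f)
        except EOFError:
            break
        total[game_tag] += 1
        solved[game_tag] += int(ok)

for game_tag in sorted(total):
    print(f"{game_tag}: {solved[game_tag] / total[game_tag]:.2%} ({total[game_tag]} items)")
```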

agents/check_param.py
ADDED
@@ -0,0 +1,9 @@
+import os
+from transformers import AutoModelForCausalLM
+model_name = os.getenv('MODEL_NAME')
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",
+    torch_dtype="bfloat16",
+)
+print(model_name, sum(p.numel() for p in model.parameters()), model.num_parameters())
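
`agents/check_param.py` loads the full checkpoint just to report its parameter count. When only the count is needed, a lighter variant (a sketch, not part of this commit; the `google/gemma-2-9b-it` default is an assumption) can build the architecture on the meta device via `accelerate`, so no real weight memory is allocated:

```python
# Sketch: count parameters without materializing weights, using accelerate's
# meta-device context. MODEL_NAME default below is illustrative only.
import os
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

model_name = os.getenv("MODEL_NAME", "google/gemma-2-9b-it")  # assumed default
config = AutoConfig.from_pretrained(model_name)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)  # layers exist, weights stay empty
print(model_name, sum(p.numel() for p in model.parameters()))
```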

oauth_environ_google.sh
CHANGED
@@ -1 +1 @@
-export $(cat
+export $(cat ${ENVFILE} | xargs)
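
The one-line script exports every `KEY=VALUE` pair from the file named by `ENVFILE` into the shell environment. For completeness, a rough Python equivalent (a sketch only; the `.env` fallback is an assumption, and comment or blank lines are skipped):

```python
# Sketch: load KEY=VALUE pairs from the file named by ENVFILE into os.environ.
# The ".env" fallback is an assumption, not something the script defines.
import os

env_file = os.environ.get("ENVFILE", ".env")
with open(env_file) as f:
    for line in f:
        line = line.strip()
        if line and not line.startswith("#") and "=" in line:
            key, _, value = line.partition("=")
            os.environ.setdefault(key, value)  # do not clobber existing values
```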

play_helper.py
CHANGED
@@ -7,6 +7,8 @@ import gradio as gr
 import hashlib
 from io import BytesIO
 
+from datetime import datetime
+
 from textgames import GAME_NAMES as _GAME_NAMES, LEVEL_IDS, LEVELS, new_game, preload_game, game_filename
 from textgames.islands.islands import Islands
 from textgames.sudoku.sudoku import Sudoku
@@ -39,8 +41,10 @@ def declare_components(demo, greet, use_login_button=False):
 logout_btn = gr.Button("Logout", link="/logout", variant='huggingface', size='sm', elem_id="btn-logout")
 reset_sid_btn = gr.Button(interactive=False, visible=False, size='sm')
 with gr.Column(scale=2):
-solved_games_df = gr.DataFrame(
-
+solved_games_df = gr.DataFrame(
+pd.DataFrame({g.split('\t', 1)[0]: ['∅'] for g in GAME_NAMES}), label="Attempted Games",
+row_count=(1, 'fixed'), col_count=(8, 'fixed'), interactive=False, elem_id="df-solved-games",
+)
 level_radio = gr.Radio(LEVELS, label="Level", elem_id="radio-level-name", visible=False)
 game_radio = gr.Radio(GAME_NAMES, label="Game", elem_id="radio-game-name", visible=False)
 new_game_btn = gr.Button("Start Game", elem_id="btn-start-game", visible=False)
@@ -69,7 +73,7 @@ def declare_components(demo, greet, use_login_button=False):
 ).then(
 lambda: gr.update(interactive=False), None, [new_game_btn],
 ).then(
-check_played_game, [solved_games,
+check_played_game, [user_state, solved_games, solved_games_df], [solved_games, solved_games_df]
 ).then(
 lambda uid: ([gr.update(visible=True, interactive=True)] if uid else
 [gr.update(visible=False, interactive=False)]) * 3,
@@ -643,12 +647,12 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
 js=js_submit)
 give_up_checkbox = gr.Checkbox(False, visible=False, interactive=False)
 give_up_btn.click(
-
-).then(
+# lambda: (gr.update(interactive=False), gr.update(interactive=False)), None, [submit_btn, give_up_btn]
+# ).then(
 lambda x: x, [give_up_checkbox], [give_up_checkbox],
 js="(x) => confirm('🥹 Give-up? 💸')"
-).then(
-
+# ).then(
+# lambda: (gr.update(interactive=True), gr.update(interactive=True)), None, [submit_btn, give_up_btn]
 )
 
 def _forfeiting(confirmed, _solved_games):
@@ -657,7 +661,7 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
 cur_game.finish_stats_(forfeit=True)
 if level in LEVELS and level not in _solved_games[game_name]:
 _solved_games[game_name].append(level)
-upload_to_drive(fp_out)
+upload_to_drive(fp_out, update=True)
 return 0, _solved_games
 return 1, _solved_games
 give_up_checkbox.change(
@@ -696,7 +700,8 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
 'game_name': game_name, 'difficulty_level': difficulty_level,
 }, f)
 f.write("\n")
-
+print(f" >>> Solved @ {datetime.now()}:", uid, sid, game_name, level, sep=" ")
+upload_to_drive(fp_out, update=True)
 upload_to_drive(_leaderboards, update=True)
 return gr.update(interactive=True)
 return gr.update()
@@ -715,13 +720,15 @@ def start_new_game(game_name, level, session_state_component, is_solved_componen
 
 # %%
 def check_to_start_new_game(game_name, level, user=None, uid=None, sid=None):
-
+if not sid and isinstance(user, dict):
+sid = user.get('sid', None)
+print(f" >>> Starts @ {datetime.now()}:", uid, sid, game_name, level, sep=" ")
 if game_name is None or level is None:
 raise gr.Error("please choose both Game & Level")
 fp = _get_file_output(game_name, LEVEL_IDS[LEVELS.index(level)], f"{uid}_{sid}")
 if os.path.exists(fp):
 # raise gr.Error(f"You have done this game already.<br/>{game_name} - {level}")
-gr.Warning("You have done this game already
+gr.Warning("You have done this game already.<br/>Only the first attempt is recorded on the leaderboard.")
 if user is None:
 gr.Warning("no user, game will be generated randomly")
 # else:
@@ -733,11 +740,11 @@ def check_to_start_new_game(game_name, level, user=None, uid=None, sid=None):
 
 
 # %%
-def check_played_game(solved_games,
+def check_played_game(user, solved_games, solved_games_df, progress=gr.Progress()):
 uid = user['email']
 sid = user.get('sid', None)
 matches = _files.list(
-q=f"'{_folder_id}' in parents and mimeType='application/octet-stream' and name contains '{uid}_-_'",
+q=f"'{_folder_id}' in parents and mimeType='application/octet-stream' and name contains '{uid}_{sid}_-_'",
 fields=f"files(name, id, {_cksm_methods_str})",
 ).execute()
 matches = matches['files']
@@ -747,10 +754,14 @@ def check_played_game(solved_games, user, progress=gr.Progress()):
 for level, level_id in zip(LEVELS, LEVEL_IDS):
 fp_out = _get_file_output(game_name, level_id, f"{uid}_{sid}")
 _matches = list(filter(lambda m: fp_out.endswith(m['name']), matches))
-if os.path.exists(fp_out):
-
-
-
+if _matches and not os.path.exists(fp_out):
+os.system(f"touch \"{fp_out}\"")
+elif not _matches and os.path.exists(fp_out):
+upload_to_drive(fp_out, _matches, update=True)
+# if os.path.exists(fp_out):
+# upload_to_drive(fp_out, _matches, update=True)
+# else:
+# download_from_drive(fp_out, _matches)
 if os.path.exists(fp_out):
 cur.append(level)
 ret[game_name] = cur or '∅'
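
The reworked `check_played_game` treats the Google Drive listing as the source of truth and local output files as placeholders: if Drive already has a result that is missing locally, an empty marker is touched; if a local result was never uploaded, it is pushed. A condensed sketch of that per-(game, level) reconciliation rule (the helper names here are illustrative, not the module's API):

```python
# Sketch of the reconciliation rule added to check_played_game.
# `remote_matches` stands in for the Drive listing and `upload` for upload_to_drive.
from pathlib import Path

def reconcile(fp_out: str, remote_matches: list, upload) -> bool:
    """Return True if this (game, level) should count as already attempted."""
    local = Path(fp_out)
    if remote_matches and not local.exists():
        local.touch()      # remote result exists -> create an empty local marker
    elif not remote_matches and local.exists():
        upload(fp_out)     # local result was never uploaded -> push it to Drive
    return local.exists()
```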

play_with_hf.py
CHANGED
@@ -64,12 +64,12 @@ with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
 
 reset_sid_checkbox = gr.Checkbox(False, visible=False, interactive=False)
 reset_sid_btn.click(
-
-).then(
+# lambda: [gr.update(interactive=False)]*2, None, [reset_sid_btn, new_game_btn]
+# ).then(
 lambda x: x, [reset_sid_checkbox], [reset_sid_checkbox],
-js="(x) => confirm('
-).then(
-
+js="(x) => confirm('Only your best session is recorded on the leaderboard. Are you sure you want to start from the beginning? (cannot be undone)')"
+# ).then(
+# lambda: [gr.update(interactive=True)]*2, None, [reset_sid_btn, new_game_btn]
 )
 
 def _resetting(confirmed, user):
@@ -78,13 +78,15 @@ with gr.Blocks(title="TextGames", css=css, delete_cache=(3600, 3600)) as demo:
 gr.Warning("You need to log in first!")
 elif confirmed:
 user['sid'] = get_sid(uid, force_generate_sid=True)
+gr.Info("Successfully resets the game with new session. Enjoy the game! 💪")
 return user, False
 reset_sid_checkbox.change(
 lambda: [gr.update(interactive=False)]*3, None, [logout_btn, reset_sid_btn, new_game_btn]
 ).then(
 _resetting, [reset_sid_checkbox, user_state], [user_state, reset_sid_checkbox]
 ).then(
-check_played_game, [solved_games,
+check_played_game, [user_state, solved_games, solved_games_df], [solved_games, solved_games_df]
+
 ).then(
 lambda: [gr.update(interactive=True)]*3, None, [logout_btn, reset_sid_btn, new_game_btn]
 )
@@ -96,30 +98,72 @@
 start_new_game(game_name, level, session_state, is_solved, solved_games, user=user, uid=_uid_state)
 
 #%%
-with demo.route("Leaderboards", "/
-gr.Markdown("Under Construction. Will be available soon.")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+with (demo.route("Leaderboards", "/leaderboards") as demo_leaderboard):
+# gr.Markdown("Under Construction. Will be available soon.")
+def reload_leaderboard():
+ret_leaderboards = {}
+
+def add_dummies():
+return pd.DataFrame({
+'User': ['dummy'],
+'Solved': [sorted([g.split('\t', 1)[0] for g in GAME_NAMES])],
+'Attempts': [888],
+'Time': [8888.8888],
+})
+
+if not os.path.exists(_leaderboards):
+for lv in ['1', '2', '3']:
+ret_leaderboards[lv] = add_dummies()
+
+else:
+datas = []
+with open(_leaderboards, "r", encoding="utf8") as f:
+for line in f:
+datas.append(json.loads(line))
+concat = [{'Level': d['difficulty_level'], 'User': d['uid'], 'Session': d['sid'],
+'Solved': d['game_name'].split('\t', 1)[0], 'Attempts': d['turns'], "Time": d['ed'] - d['st']
+} for d in datas]
+df_leaderboards_all = pd.DataFrame(concat)
+
+def get_best(_cur_df):
+def _per_session(_df):
+best = _df.groupby("Solved").apply(
+lambda _df: _df.sort_values(["Attempts", "Time"]).iloc[0]
+).reset_index(drop=True)
+ret = pd.DataFrame({
+"Solved": [sorted(best.Solved.unique())], "Attempts": best.Attempts.sum(), "Time": best.Time.sum(),
+})
+return ret
+flat = _cur_df.groupby("Session").apply(_per_session)
+srt = flat.sort_values(["Solved", "Attempts", "Time"], key=lambda c: {
+"Solved": lambda s: -s.apply(len),
+}.get(c.name, lambda s: s)(c))
+return srt.iloc[0]
+
+for lv in ['1', '2', '3']:
+cur_df = df_leaderboards_all.loc[df_leaderboards_all.Level.eq(lv)].groupby("User").apply(get_best)
+ret_leaderboards[lv] = cur_df.reset_index() if len(cur_df) else add_dummies()
+
+return ret_leaderboards
+
+df_leaderboards = {}
+
+# for lv, tab_name in [('1', "🚅 Easy"), ('2', "🚀 Medium"), ('3', "🛸 Hard")]:
+with gr.Tab("🚅 Easy") as tab1:
+lb_df_1 = gr.DataFrame(label="Rankings", col_count=(4, 'fixed'), interactive=False, show_search='filter')
+tab1.select(lambda: df_leaderboards['1'], None, [lb_df_1])
+with gr.Tab("🚀 Medium") as tab2:
+lb_df_2 = gr.DataFrame(label="Rankings", col_count=(4, 'fixed'), interactive=False, show_search='filter')
+tab2.select(lambda: df_leaderboards['2'], None, [lb_df_2])
+with gr.Tab("🛸 Hard") as tab3:
+lb_df_3 = gr.DataFrame(label="Rankings", col_count=(4, 'fixed'), interactive=False, show_search='filter')
+tab3.select(lambda: df_leaderboards['3'], None, [lb_df_3])
+
+def onload(progress=gr.Progress()):
+global df_leaderboards
+df_leaderboards = reload_leaderboard()
+return df_leaderboards['1']
+demo_leaderboard.load(onload, None, [lb_df_1])
 
 
 #%%
@@ -130,3 +174,18 @@ demo.launch(
 )
 
 
+#%%
+
+#%%
+
+
+#%%
+
+
+#%%
+
+
+#%%
+
+
+#%%

textgames_check_model_outputs.py
ADDED
@@ -0,0 +1,172 @@
+# %%
+import json
+import pickle
+import re
+from pathlib import Path
+
+
+# %%
+def load_pickle(fp):
+with open(fp, "rb") as f:
+try:
+while True:
+yield pickle.load(f)
+except EOFError:
+pass
+
+
+# %%
+fd = Path("model_outputs")
+
+# %%
+
+
+# %%
+
+
+# %%
+
+
+# # %%
+# # concat pickle results (1/22)
+# list(fd.glob("results_gemma_*"))[0]
+#
+# # %%
+# fps = sorted(fd.glob("results_gemma_*"))
+# all_responses = dict()
+# errors = set()
+# for fp in fps:
+# responses = list(load_pickle(str(fp)))
+# print(fp.name, len(responses), responses[0][0], responses[-1][0])
+# for r in responses:
+# if r[-1]:
+# errors.add((r[0], str(r[-1])))
+# all_responses.setdefault(r[:2], set())
+# all_responses[r[:2]].add(r)
+# errors = sorted(errors)
+#
+# # %%
+# assert all(len(v) == 1 for v in all_responses.values()), f"Duplicated response(s) found"
+#
+# # %%
+# duplicated = {k: v for k, v in all_responses.items() if len(v) > 1}
+#
+# # %%
+# concatenated = [list(v)[0] for v in all_responses.values()]
+#
+# # %%
+# with open(fd / "gemma2_9b_results_depre_250122/results_gemma-2-9b-it.single_turn.jsonl", "w", encoding="utf8") as o:
+# for i in concatenated:
+# json.dump({
+# "game": i[0],
+# "session": i[1],
+# "turn": 1,
+# "response": i[2],
+# "solved": i[3][0],
+# "val_msg": i[3][1],
+# "error": repr(i[4]) if i[4] else i[4],
+# }, o, ensure_ascii=False)
+# o.write("\n")
+
+# %%
+
+
+# %%
+
+
+# %%
+
+
+# %%
+
+
+# %%
+
+
+# %%
+# Rerun gemma, resolving errors
+
+# %%
+import os
+import json
+import pandas as pd
+
+# %%
+os.environ["TG_GAME_ST"] = "7"
+os.environ["TG_GAME_ED"] = "8"
+
+# %%
+st, ed = os.getenv("TG_GAME_ST", None), os.getenv("TG_GAME_ED", None)
+st, ed = ((None if x is None else int(x)) for x in (st, ed))
+fp_out = f"model_outputs/results_gemma-2-9b-it{'' if st is None else f'.{st}'}.jsonl"
+
+
+# %%
+from tqdm import tqdm
+from itertools import product
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from textgames import THE_GAMES, GAME_NAMES, LEVEL_IDS, game_filename, _game_class_from_name
+
+os.environ.setdefault("TEXTGAMES_OUTPUT_DIR", "user_outputs")
+
+
+# %%
+with open(fd / "gemma2_9b_results_depre_250122/results_gemma-2-9b-it.single_turn.jsonl", "r", encoding="utf-8") as f:
+df = pd.read_json(f, lines=True)
+
+# %%
+df.columns
+
+
+# %%
+from agents import run_with_agent
+from agents.gemma_2_9b_it import gemma_postproc
+
+
+# %%
+def get_buffered_response(texts, game_name, difficulty_level, turn):
+if turn > 1:
+return None
+cur_df = df.loc[(df.game == f"{game_filename(game_name)}_{difficulty_level}")].set_index(["session", "turn"])
+with open(f"problemsets/{game_filename(game_name)}_{difficulty_level}.json", "r", encoding="utf8") as f:
+_sid_prompt_dict = json.load(f)
+prompt_sid_dict = {v: k for k, v in _sid_prompt_dict.items()}
+sid = prompt_sid_dict[texts[0]]
+try:
+return cur_df.loc[(sid, turn)].response
+except KeyError:
+return None
+
+
+# %%
+run_with_agent(fp_out, get_buffered_response, get_postprocess=gemma_postproc, game_names_list=GAME_NAMES[st:ed], n_turns=1)
+
+
+# %%
+# type(cur_df.loc[(sid, 1)].response)
+
+
+# %%
+
+
+# %%
+
+
+# %%
+
+
+# %%
+
+
+# %%
+
+
+# %%
+
+
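
In the rerun script, `run_with_agent` is pointed at `GAME_NAMES[st:ed]`, so the `TG_GAME_ST`/`TG_GAME_ED` variables shard the games across separate processes. A small stand-alone illustration of that slicing convention (the dummy game list is made up):

```python
# Illustrative sharding helper mirroring the GAME_NAMES[st:ed] slicing above:
# unset bounds mean "from the start" / "to the end".
import os

def game_slice(games):
    st = os.getenv("TG_GAME_ST")
    ed = os.getenv("TG_GAME_ED")
    return games[None if st is None else int(st):None if ed is None else int(ed)]

if __name__ == "__main__":
    print(game_slice([f"game{i}" for i in range(8)]))
```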