{"cells":[{"cell_type":"code","execution_count":32,"metadata":{"executionInfo":{"elapsed":476,"status":"ok","timestamp":1720679526275,"user":{"displayName":"HUANG DONGHAO _","userId":"00977795705617022768"},"user_tz":-480},"id":"uWKRSV6eZsCn"},"outputs":[{"name":"stdout","output_type":"stream","text":["loading /Users/inflaton/code/engd/projects/logical-reasoning/llm_toolkit/logical_reasoning_utils.py\n","The autoreload extension is already loaded. To reload it, use:\n"," %reload_ext autoreload\n"]}],"source":["# Reload imported modules automatically before each cell is executed.\n","%load_ext autoreload\n","%autoreload 2"]},{"cell_type":"code","execution_count":33,"metadata":{"application/vnd.databricks.v1+cell":{"cellMetadata":{"byteLimit":2048000,"rowLimit":10000},"inputWidgets":{},"nuid":"eb33b19f-1206-41ee-84e2-e6258a12eef7","showTitle":false,"title":""},"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2534,"status":"ok","timestamp":1720679529344,"user":{"displayName":"HUANG DONGHAO _","userId":"00977795705617022768"},"user_tz":-480},"id":"xwFh14uiZBrI","outputId":"d767799c-34c2-46a5-f052-378146a55321"},"outputs":[],"source":["from pathlib import Path\n","\n","# Resolve the project root only once; a value already defined in this kernel is kept.\n","if \"workding_dir\" not in locals():\n","    try:\n","        # Colab path: mount Google Drive, where the project folder is stored.\n","        from google.colab import drive\n","\n","        drive.mount(\"/content/drive\")\n","    except ModuleNotFoundError:\n","        # google.colab is not installed: use the parent of the current directory.\n","        workding_dir = str(Path.cwd().parent)\n","    else:\n","        workding_dir = \"/content/drive/MyDrive/logical-reasoning/\""]},{"cell_type":"code","execution_count":34,"metadata":{"application/vnd.databricks.v1+cell":{"cellMetadata":{"byteLimit":2048000,"rowLimit":10000},"inputWidgets":{},"nuid":"6d394937-6c99-4a7c-9d32-7600a280032f","showTitle":false,"title":""},"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":5,"status":"ok","timestamp":1720679529345,"user":{"displayName":"HUANG DONGHAO _","userId":"00977795705617022768"},"user_tz":-480},"id":"G5pNu3zgZBrL","outputId":"160a554f-fb08-4aa0-bc00-0422fb7c1fac"},"outputs":[{"name":"stdout","output_type":"stream","text":["workding dir: 
/Users/inflaton/code/engd/projects/logical-reasoning\n"]}],"source":["import os\n","import sys\n","from pathlib import Path\n","\n","# Run from the project root and make it importable.\n","os.chdir(workding_dir)\n","# Append only when missing so re-running the cell does not duplicate the sys.path entry.\n","if workding_dir not in sys.path:\n","    sys.path.append(workding_dir)\n","print(\"workding dir:\", workding_dir)"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["working dir: /Users/inflaton/code/engd/projects/logical-reasoning\n"]}],"source":["# haotian comp\n","import os\n","import sys\n","from pathlib import Path\n","\n","# Fall back to the parent of the current directory when no earlier cell set the project root.\n","if \"workding_dir\" not in locals():\n","    workding_dir = str(Path.cwd().parent)\n","os.chdir(workding_dir)\n","# The path may already be on sys.path from an earlier run; skip the append if it is present.\n","if workding_dir not in sys.path:\n","    sys.path.append(workding_dir)\n","print(\"working dir:\", workding_dir)"]},{"cell_type":"code","execution_count":36,"metadata":{"application/vnd.databricks.v1+cell":{"cellMetadata":{"byteLimit":2048000,"rowLimit":10000},"inputWidgets":{},"nuid":"9f67ec60-2f24-411c-84eb-0dd664b44775","showTitle":false,"title":""},"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3,"status":"ok","timestamp":1720679529345,"user":{"displayName":"HUANG DONGHAO _","userId":"00977795705617022768"},"user_tz":-480},"id":"hPCC-6m7ZBrM","outputId":"c7aa2c96-5e99-440a-c148-201d79465ff9"},"outputs":[{"name":"stdout","output_type":"stream","text":["loading env vars from: /Users/inflaton/code/engd/projects/logical-reasoning/.env\n"]},{"data":{"text/plain":["True"]},"execution_count":36,"metadata":{},"output_type":"execute_result"}],"source":["from dotenv import find_dotenv, load_dotenv\n","\n","# Prefer .env; fall back to .env.example when it cannot be found.\n","found_dotenv = find_dotenv(\".env\")\n","\n","# find_dotenv returns an empty string when the file is not found.\n","if not found_dotenv:\n","    found_dotenv = find_dotenv(\".env.example\")\n","print(f\"loading env vars from: {found_dotenv}\")\n","# override=True lets values from the file replace variables already set in the environment.\n","load_dotenv(found_dotenv, 
override=True)"]},{"cell_type":"code","execution_count":37,"metadata":{"application/vnd.databricks.v1+cell":{"cellMetadata":{"byteLimit":2048000,"rowLimit":10000},"inputWidgets":{},"nuid":"f1597656-8042-4878-9d3b-9ebfb8dd86dc","showTitle":false,"title":""},"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3,"status":"ok","timestamp":1720679529345,"user":{"displayName":"HUANG DONGHAO _","userId":"00977795705617022768"},"user_tz":-480},"id":"1M3IraVtZBrM","outputId":"29ab35f6-2970-4ade-d85d-3174acf8cda0"},"outputs":[],"source":["# Maps each model name to an integer rank.\n","# NOTE(review): presumably used as a sort key for tables/plots - confirm against usage.\n","model_orders = {\n","    \"Mistral-7B-v0.3-Chinese-Chat\": 5,\n","    \"internlm2_5-7b-chat\": 9,\n","    \"internlm2_5-7b-chat-1m\": 10,\n","    \"Qwen2-7B-Instruct\": 20,\n","    \"Llama3.1-8B-Chinese-Chat\": 30,\n","    \"Llama3.1-70B-Chinese-Chat\": 40,\n","    \"Qwen2-72B-Instruct\": 50,\n","}"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[],"source":["# NOTE(review): these look like Matplotlib marker codes - confirm against the plotting code.\n","markers = [\n","    \"o\",\n","    \"x\",\n","    \"^\",\n","    \"s\",\n","    \"d\",\n","    \"P\",\n","    \"X\",\n","    \"*\",\n","    \"v\",\n","    \">\",\n","    \"<\",\n","    \"p\",\n","    \"h\",\n","    \"H\",\n","    \"+\",\n","    \"|\",\n","    \"_\",\n","]\n","# Assign one marker per model in model_orders insertion order; wrap around the marker\n","# list so that having more models than markers cannot raise IndexError.\n","model_markers = {\n","    name: markers[i % len(markers)] for i, name in enumerate(model_orders)\n","}"]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"data":{"text/html":["
\n"," | epoch | \n","model | \n","run | \n","accuracy | \n","precision | \n","recall | \n","f1 | \n","ratio_valid_classifications | \n","
---|---|---|---|---|---|---|---|---|
0 | \n","0.0 | \n","Mistral-7B-v0.3-Chinese-Chat | \n","shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat_torc... | \n","0.711333 | \n","0.702205 | \n","0.711333 | \n","0.689497 | \n","0.004 | \n","
1 | \n","0.2 | \n","Mistral-7B-v0.3-Chinese-Chat | \n","shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/chec... | \n","0.702000 | \n","0.793273 | \n","0.702000 | \n","0.734271 | \n","1.000 | \n","
2 | \n","0.4 | \n","Mistral-7B-v0.3-Chinese-Chat | \n","shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/chec... | \n","0.742000 | \n","0.789829 | \n","0.742000 | \n","0.753668 | \n","1.000 | \n","
3 | \n","0.6 | \n","Mistral-7B-v0.3-Chinese-Chat | \n","shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/chec... | \n","0.659667 | \n","0.792340 | \n","0.659667 | \n","0.706754 | \n","1.000 | \n","
4 | \n","0.8 | \n","Mistral-7B-v0.3-Chinese-Chat | \n","shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat/chec... | \n","0.714667 | \n","0.786134 | \n","0.714667 | \n","0.740468 | \n","1.000 | \n","
... | \n","... | \n","... | \n","... | \n","... | \n","... | \n","... | \n","... | \n","... | \n","
6 | \n","1.2 | \n","Qwen2-72B-Instruct | \n","Qwen/Qwen2-72B-Instruct/checkpoint-210_torch.b... | \n","0.763000 | \n","0.831888 | \n","0.763000 | \n","0.790108 | \n","1.000 | \n","
7 | \n","1.4 | \n","Qwen2-72B-Instruct | \n","Qwen/Qwen2-72B-Instruct/checkpoint-245_torch.b... | \n","0.765667 | \n","0.828827 | \n","0.765667 | \n","0.790627 | \n","1.000 | \n","
8 | \n","1.6 | \n","Qwen2-72B-Instruct | \n","Qwen/Qwen2-72B-Instruct/checkpoint-280_torch.b... | \n","0.769333 | \n","0.829280 | \n","0.769333 | \n","0.793017 | \n","1.000 | \n","
9 | \n","1.8 | \n","Qwen2-72B-Instruct | \n","Qwen/Qwen2-72B-Instruct/checkpoint-315_torch.b... | \n","0.784000 | \n","0.835435 | \n","0.784000 | \n","0.804195 | \n","1.000 | \n","
10 | \n","2.0 | \n","Qwen2-72B-Instruct | \n","Qwen/Qwen2-72B-Instruct/checkpoint-350_torch.b... | \n","0.773667 | \n","0.833015 | \n","0.773667 | \n","0.797366 | \n","1.000 | \n","
76 rows × 8 columns
\n","\n"," | index | \n","model | \n","run | \n","accuracy | \n","precision | \n","recall | \n","f1 | \n","ratio_valid_classifications | \n","
---|---|---|---|---|---|---|---|---|
0 | \n","1 | \n","internlm2_5-7b-chat | \n","internlm2_5-7b-chat | \n","0.749667 | \n","0.804187 | \n","0.749667 | \n","0.766016 | \n","1.000000 | \n","
1 | \n","2 | \n","internlm2_5-7b-chat-1m | \n","internlm2_5-7b-chat-1m | \n","0.803000 | \n","0.803141 | \n","0.803000 | \n","0.802806 | \n","1.000000 | \n","
2 | \n","3 | \n","Mistral-7B-v0.3-Chinese-Chat | \n","Mistral-7B-v0.3-Chinese-Chat | \n","0.750000 | \n","0.788587 | \n","0.750000 | \n","0.764823 | \n","1.000000 | \n","
3 | \n","4 | \n","Qwen2-7B-Instruct | \n","Qwen2-7B-Instruct | \n","0.759000 | \n","0.800530 | \n","0.759000 | \n","0.774875 | \n","1.000000 | \n","
4 | \n","5 | \n","Llama3.1-8B-Chinese-Chat | \n","Llama3.1-8B-Chinese-Chat | \n","0.780000 | \n","0.810583 | \n","0.780000 | \n","0.792465 | \n","1.000000 | \n","
5 | \n","6 | \n","Llama3.1-70B-Chinese-Chat | \n","Llama3.1-70B-Chinese-Chat | \n","0.796333 | \n","0.824897 | \n","0.796333 | \n","0.807687 | \n","1.000000 | \n","
6 | \n","7 | \n","Qwen2-72B-Instruct | \n","Qwen2-72B-Instruct | \n","0.784000 | \n","0.835435 | \n","0.784000 | \n","0.804195 | \n","1.000000 | \n","
7 | \n","8 | \n","Ensemble Model | \n","Ensemble Model | \n","0.819333 | \n","0.840746 | \n","0.819333 | \n","0.828054 | \n","1.000000 | \n","
8 | \n","9 | \n","gpt-4o-mini (10-shot) | \n","gpt-4o-mini (10-shot) | \n","0.679333 | \n","0.772809 | \n","0.679333 | \n","0.691675 | \n","0.999667 | \n","
9 | \n","10 | \n","o1-mini (10-shot) | \n","o1-mini (10-shot) | \n","0.725000 | \n","0.789249 | \n","0.725000 | \n","0.748562 | \n","1.000000 | \n","
10 | \n","11 | \n","gpt-4o (10-shot) | \n","gpt-4o (10-shot) | \n","0.791667 | \n","0.822771 | \n","0.791667 | \n","0.803615 | \n","0.999667 | \n","