Files changed (34) hide show
  1. .gitattributes +0 -12
  2. .gitignore +0 -3
  3. .ipynb_checkpoints/Untitled-checkpoint.ipynb +0 -0
  4. .ipynb_checkpoints/distinguish_high_low_label-checkpoint.ipynb +0 -447
  5. Untitled.ipynb +2 -2
  6. app.py +225 -709
  7. distinguish_high_low_label.ipynb +0 -553
  8. fullTest/test.txt +0 -3
  9. fullTest/test_info.txt +0 -3
  10. fullTest/test_label.txt +0 -0
  11. new_test_saved_finetuned_model.py +1 -6
  12. plot.png +0 -0
  13. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/highGRschool10_/test.txt +0 -3
  14. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/highGRschool10_/test_info.txt +0 -3
  15. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/highGRschool10_/test_label.txt +0 -0
  16. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test.txt +0 -3
  17. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test_BKT.txt +0 -3
  18. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test_info.txt +0 -3
  19. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test_label.txt +0 -0
  20. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/highGRschool10/test_label.txt +0 -0
  21. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/lowGRschoolAll/test.txt +0 -3
  22. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/lowGRschoolAll/test_info.txt +0 -3
  23. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/lowGRschoolAll/test_label.txt +0 -0
  24. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/test.txt +0 -3
  25. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/test_info.txt +0 -3
  26. ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/test_label.txt +0 -0
  27. result.txt +7 -7
  28. roc_data.pkl +2 -2
  29. roc_data2.pkl +0 -3
  30. selected_rows.txt +0 -0
  31. test.txt +0 -0
  32. train.txt +0 -0
  33. train_info.txt +0 -3
  34. train_label.txt +0 -0
.gitattributes CHANGED
@@ -38,15 +38,3 @@ ratio_proportion_change3/output/FS/bert_fine_tuned.model.ep32 filter=lfs diff=lf
38
  ratio_proportion_change3/output/IS/bert_fine_tuned.model.ep14 filter=lfs diff=lfs merge=lfs -text
39
  ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/highGRschool10/test_info.txt filter=lfs diff=lfs merge=lfs -text
40
  ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42 filter=lfs diff=lfs merge=lfs -text
41
- train_info.txt filter=lfs diff=lfs merge=lfs -text
42
- ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/lowGRschoolAll/test_info.txt filter=lfs diff=lfs merge=lfs -text
43
- ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/lowGRschoolAll/test.txt filter=lfs diff=lfs merge=lfs -text
44
- fullTest/test_info.txt filter=lfs diff=lfs merge=lfs -text
45
- fullTest/test.txt filter=lfs diff=lfs merge=lfs -text
46
- ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/test_info.txt filter=lfs diff=lfs merge=lfs -text
47
- ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/test.txt filter=lfs diff=lfs merge=lfs -text
48
- ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test_info.txt filter=lfs diff=lfs merge=lfs -text
49
- ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test.txt filter=lfs diff=lfs merge=lfs -text
50
- ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/highGRschool10_/test_info.txt filter=lfs diff=lfs merge=lfs -text
51
- ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/highGRschool10_/test.txt filter=lfs diff=lfs merge=lfs -text
52
- ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test_BKT.txt filter=lfs diff=lfs merge=lfs -text
 
38
  ratio_proportion_change3/output/IS/bert_fine_tuned.model.ep14 filter=lfs diff=lfs merge=lfs -text
39
  ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/highGRschool10/test_info.txt filter=lfs diff=lfs merge=lfs -text
40
  ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore CHANGED
@@ -1,5 +1,2 @@
1
  train_info.txt
2
- train.txt
3
- train_label.txt
4
  ratio_proportion_change3_2223/sch_largest_100-coded/logs/
5
- ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/
 
1
  train_info.txt
 
 
2
  ratio_proportion_change3_2223/sch_largest_100-coded/logs/
 
.ipynb_checkpoints/Untitled-checkpoint.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
.ipynb_checkpoints/distinguish_high_low_label-checkpoint.ipynb DELETED
@@ -1,447 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 3,
6
- "id": "960bac80-51c7-4e9f-ad2d-84cd6c710f98",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "import pickle\n",
11
- "import pandas as pd"
12
- ]
13
- },
14
- {
15
- "cell_type": "code",
16
- "execution_count": 4,
17
- "id": "a34f21d0-0854-4a54-8f93-67718b2f969e",
18
- "metadata": {},
19
- "outputs": [],
20
- "source": [
21
- "file_path = \"roc_data2.pkl\"\n",
22
- "\n",
23
- "# Open and load the pickle file\n",
24
- "with open(file_path, 'rb') as file:\n",
25
- " data = pickle.load(file)\n",
26
- "\n",
27
- "\n",
28
- "# Print or use the data\n",
29
- "# data[2]"
30
- ]
31
- },
32
- {
33
- "cell_type": "code",
34
- "execution_count": 5,
35
- "id": "f9febed4-ce50-4e30-96ea-4b538ce2f9a1",
36
- "metadata": {},
37
- "outputs": [],
38
- "source": [
39
- "inc_slider=1\n",
40
- "parent_location=\"ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/\"\n",
41
- "test_info_location=parent_location+\"fullTest/test_info.txt\"\n",
42
- "test_location=parent_location+\"fullTest/test.txt\"\n",
43
- "test_info = pd.read_csv(test_info_location, sep=',', header=None, engine='python')\n",
44
- "grad_rate_data = pd.DataFrame(pd.read_pickle('school_grduation_rate.pkl'),columns=['school_number','grad_rate']) # Load the grad_rate data\n",
45
- "\n",
46
- "# Step 1: Extract unique school numbers from test_info\n",
47
- "unique_schools = test_info[0].unique()\n",
48
- "\n",
49
- "# Step 2: Filter the grad_rate_data using the unique school numbers\n",
50
- "schools = grad_rate_data[grad_rate_data['school_number'].isin(unique_schools)]\n",
51
- "\n",
52
- "# Define a threshold for high and low graduation rates (adjust as needed)\n",
53
- "grad_rate_threshold = 0.9 \n",
54
- "\n",
55
- "# Step 4: Divide schools into high and low graduation rate groups\n",
56
- "high_grad_schools = schools[schools['grad_rate'] >= grad_rate_threshold]['school_number'].unique()\n",
57
- "low_grad_schools = schools[schools['grad_rate'] < grad_rate_threshold]['school_number'].unique()\n",
58
- "\n",
59
- "# Step 5: Sample percentage of schools from each group\n",
60
- "high_sample = pd.Series(high_grad_schools).sample(frac=inc_slider/100, random_state=1).tolist()\n",
61
- "low_sample = pd.Series(low_grad_schools).sample(frac=inc_slider/100, random_state=1).tolist()\n",
62
- "\n",
63
- "# Step 6: Combine the sampled schools\n",
64
- "random_schools = high_sample + low_sample\n",
65
- "\n",
66
- "# Step 7: Get indices for the sampled schools\n",
67
- "indices = test_info[test_info[0].isin(random_schools)].index.tolist()\n",
68
- "\n"
69
- ]
70
- },
71
- {
72
- "cell_type": "code",
73
- "execution_count": 6,
74
- "id": "fdfdf4b6-2752-4a21-9880-869af69f20cf",
75
- "metadata": {},
76
- "outputs": [],
77
- "source": [
78
- "high_indices = test_info[(test_info[0].isin(high_sample))].index.tolist()\n",
79
- "low_indices = test_info[(test_info[0].isin(low_sample))].index.tolist()"
80
- ]
81
- },
82
- {
83
- "cell_type": "code",
84
- "execution_count": 7,
85
- "id": "a79a4598-5702-4cc8-9f07-8e18fdda648b",
86
- "metadata": {},
87
- "outputs": [
88
- {
89
- "data": {
90
- "text/plain": [
91
- "997"
92
- ]
93
- },
94
- "execution_count": 7,
95
- "metadata": {},
96
- "output_type": "execute_result"
97
- }
98
- ],
99
- "source": [
100
- "len(high_indices)+len(low_indices)\n"
101
- ]
102
- },
103
- {
104
- "cell_type": "code",
105
- "execution_count": 8,
106
- "id": "4707f3e6-2f44-46d8-ad8c-b6c244f693af",
107
- "metadata": {},
108
- "outputs": [
109
- {
110
- "data": {
111
- "text/html": [
112
- "<div>\n",
113
- "<style scoped>\n",
114
- " .dataframe tbody tr th:only-of-type {\n",
115
- " vertical-align: middle;\n",
116
- " }\n",
117
- "\n",
118
- " .dataframe tbody tr th {\n",
119
- " vertical-align: top;\n",
120
- " }\n",
121
- "\n",
122
- " .dataframe thead th {\n",
123
- " text-align: right;\n",
124
- " }\n",
125
- "</style>\n",
126
- "<table border=\"1\" class=\"dataframe\">\n",
127
- " <thead>\n",
128
- " <tr style=\"text-align: right;\">\n",
129
- " <th></th>\n",
130
- " <th>0</th>\n",
131
- " </tr>\n",
132
- " </thead>\n",
133
- " <tbody>\n",
134
- " <tr>\n",
135
- " <th>5342</th>\n",
136
- " <td>PercentChange-0\\tNumeratorQuantity1-0\\tNumerat...</td>\n",
137
- " </tr>\n",
138
- " <tr>\n",
139
- " <th>5343</th>\n",
140
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
141
- " </tr>\n",
142
- " <tr>\n",
143
- " <th>5344</th>\n",
144
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
145
- " </tr>\n",
146
- " <tr>\n",
147
- " <th>5345</th>\n",
148
- " <td>PercentChange-0\\tNumeratorQuantity2-2\\tNumerat...</td>\n",
149
- " </tr>\n",
150
- " <tr>\n",
151
- " <th>5346</th>\n",
152
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tDenomin...</td>\n",
153
- " </tr>\n",
154
- " <tr>\n",
155
- " <th>...</th>\n",
156
- " <td>...</td>\n",
157
- " </tr>\n",
158
- " <tr>\n",
159
- " <th>113359</th>\n",
160
- " <td>PercentChange-0\\tNumeratorQuantity2-2\\tNumerat...</td>\n",
161
- " </tr>\n",
162
- " <tr>\n",
163
- " <th>113360</th>\n",
164
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
165
- " </tr>\n",
166
- " <tr>\n",
167
- " <th>113361</th>\n",
168
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
169
- " </tr>\n",
170
- " <tr>\n",
171
- " <th>113362</th>\n",
172
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
173
- " </tr>\n",
174
- " <tr>\n",
175
- " <th>113363</th>\n",
176
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
177
- " </tr>\n",
178
- " </tbody>\n",
179
- "</table>\n",
180
- "<p>997 rows × 1 columns</p>\n",
181
- "</div>"
182
- ],
183
- "text/plain": [
184
- " 0\n",
185
- "5342 PercentChange-0\\tNumeratorQuantity1-0\\tNumerat...\n",
186
- "5343 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
187
- "5344 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
188
- "5345 PercentChange-0\\tNumeratorQuantity2-2\\tNumerat...\n",
189
- "5346 PercentChange-0\\tNumeratorQuantity2-0\\tDenomin...\n",
190
- "... ...\n",
191
- "113359 PercentChange-0\\tNumeratorQuantity2-2\\tNumerat...\n",
192
- "113360 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
193
- "113361 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
194
- "113362 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
195
- "113363 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
196
- "\n",
197
- "[997 rows x 1 columns]"
198
- ]
199
- },
200
- "execution_count": 8,
201
- "metadata": {},
202
- "output_type": "execute_result"
203
- }
204
- ],
205
- "source": [
206
- "# Load the test file and select rows based on indices\n",
207
- "test = pd.read_csv(test_location, sep=',', header=None, engine='python')\n",
208
- "selected_rows_df2 = test.loc[indices]\n",
209
- "selected_rows_df2"
210
- ]
211
- },
212
- {
213
- "cell_type": "code",
214
- "execution_count": 11,
215
- "id": "1d0c3d49-061f-486b-9c19-cf20945f3207",
216
- "metadata": {},
217
- "outputs": [],
218
- "source": [
219
- "graduation_groups = [\n",
220
- " 'high' if idx in high_indices else 'low' for idx in selected_rows_df2.index\n",
221
- "]\n",
222
- "# graduation_groups"
223
- ]
224
- },
225
- {
226
- "cell_type": "code",
227
- "execution_count": 43,
228
- "id": "ad0ce4a1-27fa-4867-8061-4054dbb340df",
229
- "metadata": {},
230
- "outputs": [],
231
- "source": [
232
- "t_label=data[0]\n",
233
- "p_label=data[1]"
234
- ]
235
- },
236
- {
237
- "cell_type": "code",
238
- "execution_count": 47,
239
- "id": "a4f4a2b9-3134-42ac-871b-4e117098cd0e",
240
- "metadata": {},
241
- "outputs": [],
242
- "source": [
243
- "# Step 1: Align graduation_group, t_label, and p_label\n",
244
- "aligned_labels = list(zip(graduation_groups, t_label, p_label))\n",
245
- "\n",
246
- "# Step 2: Separate the labels for high and low groups\n",
247
- "high_t_labels = [t for grad, t, p in aligned_labels if grad == 'high']\n",
248
- "low_t_labels = [t for grad, t, p in aligned_labels if grad == 'low']\n",
249
- "\n",
250
- "high_p_labels = [p for grad, t, p in aligned_labels if grad == 'high']\n",
251
- "low_p_labels = [p for grad, t, p in aligned_labels if grad == 'low']\n",
252
- "\n"
253
- ]
254
- },
255
- {
256
- "cell_type": "code",
257
- "execution_count": 50,
258
- "id": "c8e34660-83d0-46a1-a218-95d609e11729",
259
- "metadata": {},
260
- "outputs": [
261
- {
262
- "data": {
263
- "text/plain": [
264
- "997"
265
- ]
266
- },
267
- "execution_count": 50,
268
- "metadata": {},
269
- "output_type": "execute_result"
270
- }
271
- ],
272
- "source": [
273
- "len(low_t_labels)+len(high_t_labels)"
274
- ]
275
- },
276
- {
277
- "cell_type": "code",
278
- "execution_count": 51,
279
- "id": "c11050db-2636-4c50-9cd4-b9943e5cee83",
280
- "metadata": {},
281
- "outputs": [],
282
- "source": [
283
- "from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_curve, roc_auc_score"
284
- ]
285
- },
286
- {
287
- "cell_type": "code",
288
- "execution_count": 52,
289
- "id": "e1309e93-7063-4f48-bbc7-11a0d449c34e",
290
- "metadata": {},
291
- "outputs": [
292
- {
293
- "name": "stdout",
294
- "output_type": "stream",
295
- "text": [
296
- "ROC-AUC Score for High Graduation Rate Group: 0.675\n",
297
- "ROC-AUC Score for Low Graduation Rate Group: 0.7489795918367347\n"
298
- ]
299
- }
300
- ],
301
- "source": [
302
- "high_roc_auc = roc_auc_score(high_t_labels, high_p_labels) if len(set(high_t_labels)) > 1 else None\n",
303
- "low_roc_auc = roc_auc_score(low_t_labels, low_p_labels) if len(set(low_t_labels)) > 1 else None\n",
304
- "\n",
305
- "print(\"ROC-AUC Score for High Graduation Rate Group:\", high_roc_auc)\n",
306
- "print(\"ROC-AUC Score for Low Graduation Rate Group:\", low_roc_auc)"
307
- ]
308
- },
309
- {
310
- "cell_type": "code",
311
- "execution_count": 4,
312
- "id": "a99e7812-817d-4f9f-b6fa-1a58aa3a34dc",
313
- "metadata": {},
314
- "outputs": [
315
- {
316
- "ename": "TypeError",
317
- "evalue": "cannot convert the series to <class 'int'>",
318
- "output_type": "error",
319
- "traceback": [
320
- "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
321
- "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
322
- "Cell \u001b[1;32mIn[4], line 47\u001b[0m\n\u001b[0;32m 44\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(test_info_location, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m file:\n\u001b[0;32m 45\u001b[0m data \u001b[38;5;241m=\u001b[39m file\u001b[38;5;241m.\u001b[39mreadlines()\n\u001b[1;32m---> 47\u001b[0m ideal_opt_task \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mint\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtest_info\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m7\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Assuming test_info[7] is accessible and holds the ideal task (1 or 2)\u001b[39;00m\n\u001b[0;32m 49\u001b[0m \u001b[38;5;66;03m# Initialize counters\u001b[39;00m\n\u001b[0;32m 50\u001b[0m task_counts \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m 51\u001b[0m \u001b[38;5;241m1\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124monly_opt1\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124monly_opt2\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mboth\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0\u001b[39m},\n\u001b[0;32m 52\u001b[0m \u001b[38;5;241m2\u001b[39m: {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124monly_opt1\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124monly_opt2\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mboth\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0\u001b[39m}\n\u001b[0;32m 53\u001b[0m }\n",
323
- "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\pandas\\core\\series.py:230\u001b[0m, in \u001b[0;36m_coerce_method.<locals>.wrapper\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 222\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[0;32m 223\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCalling \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconverter\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m on a single element Series is \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 224\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdeprecated and will raise a TypeError in the future. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 227\u001b[0m stacklevel\u001b[38;5;241m=\u001b[39mfind_stack_level(),\n\u001b[0;32m 228\u001b[0m )\n\u001b[0;32m 229\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m converter(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39miloc[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m--> 230\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot convert the series to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconverter\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
324
- "\u001b[1;31mTypeError\u001b[0m: cannot convert the series to <class 'int'>"
325
- ]
326
- }
327
- ],
328
- "source": [
329
- "parent_location=\"ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/\"\n",
330
- "test_info_location=parent_location+\"fullTest/test_info.txt\"\n",
331
- "test_location=parent_location+\"fullTest/test.txt\"\n",
332
- "test_info = pd.read_csv(test_info_location, sep=',', header=None, engine='python')\n",
333
- "\n",
334
- "def analyze_row(row, ideal_opt_task):\n",
335
- " # Split the row into fields\n",
336
- " fields = row.split(\"\\t\")\n",
337
- "\n",
338
- " # Define tasks for OptionalTask_1, OptionalTask_2, and FinalAnswer\n",
339
- " optional_task_1_subtasks = [\"DenominatorFactor\", \"NumeratorFactor\", \"EquationAnswer\"]\n",
340
- " optional_task_2_subtasks = [\n",
341
- " \"FirstRow2:1\", \"FirstRow2:2\", \"FirstRow1:1\", \"FirstRow1:2\", \n",
342
- " \"SecondRow\", \"ThirdRow\"\n",
343
- " ]\n",
344
- " final_answer_tasks = [\"FinalAnswer\"]\n",
345
- "\n",
346
- " # Helper function to evaluate task attempts\n",
347
- " def evaluate_tasks(fields, tasks):\n",
348
- " task_status = {}\n",
349
- " for task in tasks:\n",
350
- " relevant_attempts = [f for f in fields if task in f]\n",
351
- " if any(\"OK\" in attempt for attempt in relevant_attempts):\n",
352
- " task_status[task] = \"Attempted (Successful)\"\n",
353
- " elif any(\"ERROR\" in attempt for attempt in relevant_attempts):\n",
354
- " task_status[task] = \"Attempted (Error)\"\n",
355
- " elif any(\"JIT\" in attempt for attempt in relevant_attempts):\n",
356
- " task_status[task] = \"Attempted (JIT)\"\n",
357
- " else:\n",
358
- " task_status[task] = \"Unattempted\"\n",
359
- " return task_status\n",
360
- "\n",
361
- " # Evaluate tasks for each category\n",
362
- " optional_task_1_status = evaluate_tasks(fields, optional_task_1_subtasks)\n",
363
- " optional_task_2_status = evaluate_tasks(fields, optional_task_2_subtasks)\n",
364
- "\n",
365
- " # Check if tasks have any successful attempt\n",
366
- " opt1_done = any(status == \"Attempted (Successful)\" for status in optional_task_1_status.values())\n",
367
- " opt2_done = any(status == \"Attempted (Successful)\" for status in optional_task_2_status.values())\n",
368
- "\n",
369
- " return opt1_done, opt2_done\n",
370
- "\n",
371
- "# Read data from test_info.txt\n",
372
- "with open(test_info_location, \"r\") as file:\n",
373
- " data = file.readlines()\n",
374
- "\n",
375
- "ideal_opt_task = int(test_info[6]) # Assuming test_info[7] is accessible and holds the ideal task (1 or 2)\n",
376
- "\n",
377
- "# Initialize counters\n",
378
- "task_counts = {\n",
379
- " 1: {\"only_opt1\": 0, \"only_opt2\": 0, \"both\": 0},\n",
380
- " 2: {\"only_opt1\": 0, \"only_opt2\": 0, \"both\": 0}\n",
381
- "}\n",
382
- "\n",
383
- "for row in data:\n",
384
- " row = row.strip()\n",
385
- " if not row:\n",
386
- " continue\n",
387
- " opt1_done, opt2_done = analyze_row(row, ideal_opt_task)\n",
388
- "\n",
389
- " if ideal_opt_task == 0:\n",
390
- " if opt1_done and not opt2_done:\n",
391
- " task_counts[1][\"only_opt1\"] += 1\n",
392
- " elif not opt1_done and opt2_done:\n",
393
- " task_counts[1][\"only_opt2\"] += 1\n",
394
- " elif opt1_done and opt2_done:\n",
395
- " task_counts[1][\"both\"] += 1\n",
396
- " elif ideal_opt_task == 1:\n",
397
- " if opt1_done and not opt2_done:\n",
398
- " task_counts[2][\"only_opt1\"] += 1\n",
399
- " elif not opt1_done and opt2_done:\n",
400
- " task_counts[2][\"only_opt2\"] += 1\n",
401
- " elif opt1_done and opt2_done:\n",
402
- " task_counts[2][\"both\"] += 1\n",
403
- "\n",
404
- "# Create a string output for results\n",
405
- "output_summary = \"Task Analysis Summary:\\n\"\n",
406
- "output_summary += \"-----------------------\\n\"\n",
407
- "\n",
408
- "for ideal_task, counts in task_counts.items():\n",
409
- " output_summary += f\"Ideal Task = OptionalTask_{ideal_task}:\\n\"\n",
410
- " output_summary += f\" Only OptionalTask_1 done: {counts['only_opt1']}\\n\"\n",
411
- " output_summary += f\" Only OptionalTask_2 done: {counts['only_opt2']}\\n\"\n",
412
- " output_summary += f\" Both done: {counts['both']}\\n\"\n",
413
- "\n",
414
- "print(output_summary)"
415
- ]
416
- },
417
- {
418
- "cell_type": "code",
419
- "execution_count": null,
420
- "id": "65ad9383-741f-44eb-8e8f-853ee7bc52a2",
421
- "metadata": {},
422
- "outputs": [],
423
- "source": []
424
- }
425
- ],
426
- "metadata": {
427
- "kernelspec": {
428
- "display_name": "Python 3 (ipykernel)",
429
- "language": "python",
430
- "name": "python3"
431
- },
432
- "language_info": {
433
- "codemirror_mode": {
434
- "name": "ipython",
435
- "version": 3
436
- },
437
- "file_extension": ".py",
438
- "mimetype": "text/x-python",
439
- "name": "python",
440
- "nbconvert_exporter": "python",
441
- "pygments_lexer": "ipython3",
442
- "version": "3.12.4"
443
- }
444
- },
445
- "nbformat": 4,
446
- "nbformat_minor": 5
447
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Untitled.ipynb CHANGED
@@ -623,7 +623,7 @@
623
  "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/base-cu113:m122"
624
  },
625
  "kernelspec": {
626
- "display_name": "Python 3 (ipykernel)",
627
  "language": "python",
628
  "name": "python3"
629
  },
@@ -637,7 +637,7 @@
637
  "name": "python",
638
  "nbconvert_exporter": "python",
639
  "pygments_lexer": "ipython3",
640
- "version": "3.12.4"
641
  }
642
  },
643
  "nbformat": 4,
 
623
  "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/base-cu113:m122"
624
  },
625
  "kernelspec": {
626
+ "display_name": "Python 3",
627
  "language": "python",
628
  "name": "python3"
629
  },
 
637
  "name": "python",
638
  "nbconvert_exporter": "python",
639
  "pygments_lexer": "ipython3",
640
+ "version": "3.10.14"
641
  }
642
  },
643
  "nbformat": 4,
app.py CHANGED
@@ -8,43 +8,24 @@ import shutil
8
  import matplotlib.pyplot as plt
9
  from sklearn.metrics import roc_curve, auc
10
  import pandas as pd
11
- import plotly.graph_objects as go
12
- from sklearn.metrics import roc_auc_score
13
- from matplotlib.figure import Figure
14
  # Define the function to process the input file and model selection
15
 
16
- def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
17
  # progress = gr.Progress(track_tqdm=True)
18
-
19
  progress(0, desc="Starting the processing")
20
- # with open(file.name, 'r') as f:
21
- # content = f.read()
22
- # saved_test_dataset = "train.txt"
23
- # saved_test_label = "train_label.txt"
24
- # saved_train_info="train_info.txt"
25
  # Save the uploaded file content to a specified location
26
- # shutil.copyfile(file.name, saved_test_dataset)
27
- # shutil.copyfile(label.name, saved_test_label)
28
- # shutil.copyfile(info.name, saved_train_info)
29
- parent_location="ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/"
30
- test_info_location=parent_location+"fullTest/test_info.txt"
31
- test_location=parent_location+"fullTest/test.txt"
32
- if(model_name=="ASTRA-FT-HGR"):
33
- finetune_task="highGRschool10"
34
- # test_info_location=parent_location+"fullTest/test_info.txt"
35
- # test_location=parent_location+"fullTest/test.txt"
36
- elif(model_name== "ASTRA-FT-LGR" ):
37
- finetune_task="lowGRschoolAll"
38
- # test_info_location=parent_location+"lowGRschoolAll/test_info.txt"
39
- # test_location=parent_location+"lowGRschoolAll/test.txt"
40
- elif(model_name=="ASTRA-FT-FULL"):
41
- # test_info_location=parent_location+"fullTest/test_info.txt"
42
- # test_location=parent_location+"fullTest/test.txt"
43
- finetune_task="fullTest"
44
- else:
45
- finetune_task=None
46
  # Load the test_info file and the graduation rate file
47
- test_info = pd.read_csv(test_info_location, sep=',', header=None, engine='python')
48
  grad_rate_data = pd.DataFrame(pd.read_pickle('school_grduation_rate.pkl'),columns=['school_number','grad_rate']) # Load the grad_rate data
49
 
50
  # Step 1: Extract unique school numbers from test_info
@@ -69,50 +50,24 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
69
 
70
  # Step 7: Get indices for the sampled schools
71
  indices = test_info[test_info[0].isin(random_schools)].index.tolist()
72
- high_indices = test_info[(test_info[0].isin(high_sample))].index.tolist()
73
- low_indices = test_info[(test_info[0].isin(low_sample))].index.tolist()
74
-
75
  # Load the test file and select rows based on indices
76
- test = pd.read_csv(test_location, sep=',', header=None, engine='python')
77
  selected_rows_df2 = test.loc[indices]
78
 
79
  # Save the selected rows to a file
80
  selected_rows_df2.to_csv('selected_rows.txt', sep='\t', index=False, header=False, quoting=3, escapechar=' ')
81
 
82
- graduation_groups = [
83
- 'high' if idx in high_indices else 'low' for idx in selected_rows_df2.index
84
- ]
85
- # Group data by opt_task1 and opt_task2 based on test_info[6]
86
- opt_task_groups = ['opt_task1' if test_info.loc[idx, 6] == 0 else 'opt_task2' for idx in selected_rows_df2.index]
87
-
88
- with open("roc_data2.pkl", 'rb') as file:
89
- data = pickle.load(file)
90
- t_label=data[0]
91
- p_label=data[1]
92
- # Step 1: Align graduation_group, t_label, and p_label
93
- aligned_labels = list(zip(graduation_groups, t_label, p_label))
94
- opt_task_aligned = list(zip(opt_task_groups, t_label, p_label))
95
- # Step 2: Separate the labels for high and low groups
96
- high_t_labels = [t for grad, t, p in aligned_labels if grad == 'high']
97
- low_t_labels = [t for grad, t, p in aligned_labels if grad == 'low']
98
-
99
- high_p_labels = [p for grad, t, p in aligned_labels if grad == 'high']
100
- low_p_labels = [p for grad, t, p in aligned_labels if grad == 'low']
101
-
102
- opt_task1_t_labels = [t for task, t, p in opt_task_aligned if task == 'opt_task1']
103
- opt_task1_p_labels = [p for task, t, p in opt_task_aligned if task == 'opt_task1']
104
-
105
- opt_task2_t_labels = [t for task, t, p in opt_task_aligned if task == 'opt_task2']
106
- opt_task2_p_labels = [p for task, t, p in opt_task_aligned if task == 'opt_task2']
107
-
108
- high_roc_auc = roc_auc_score(high_t_labels, high_p_labels) if len(set(high_t_labels)) > 1 else None
109
- low_roc_auc = roc_auc_score(low_t_labels, low_p_labels) if len(set(low_t_labels)) > 1 else None
110
-
111
- opt_task1_roc_auc = roc_auc_score(opt_task1_t_labels, opt_task1_p_labels) if len(set(opt_task1_t_labels)) > 1 else None
112
- opt_task2_roc_auc = roc_auc_score(opt_task2_t_labels, opt_task2_p_labels) if len(set(opt_task2_t_labels)) > 1 else None
113
-
114
  # For demonstration purposes, we'll just return the content with the selected model name
115
-
 
 
 
 
 
 
 
116
  # print(checkpoint)
117
  progress(0.1, desc="Files created and saved")
118
  # if (inc_val<5):
@@ -121,189 +76,11 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
121
  # model_name="highGRschool10"
122
  # else:
123
  # model_name="highGRschool10"
124
- # Function to analyze each row
125
- def analyze_row(row):
126
- # Split the row into fields
127
- fields = row.split("\t")
128
-
129
- # Define tasks for OptionalTask_1, OptionalTask_2, and FinalAnswer
130
- optional_task_1_subtasks = ["DenominatorFactor", "NumeratorFactor", "EquationAnswer"]
131
- optional_task_2_subtasks = [
132
- "FirstRow2:1", "FirstRow2:2", "FirstRow1:1", "FirstRow1:2",
133
- "SecondRow", "ThirdRow"
134
- ]
135
-
136
- # Helper function to evaluate task attempts
137
- def evaluate_tasks(fields, tasks):
138
- task_status = {}
139
- for task in tasks:
140
- relevant_attempts = [f for f in fields if task in f]
141
- if any("OK" in attempt for attempt in relevant_attempts):
142
- task_status[task] = "Attempted (Successful)"
143
- elif any("ERROR" in attempt for attempt in relevant_attempts):
144
- task_status[task] = "Attempted (Error)"
145
- elif any("JIT" in attempt for attempt in relevant_attempts):
146
- task_status[task] = "Attempted (JIT)"
147
- else:
148
- task_status[task] = "Unattempted"
149
- return task_status
150
-
151
- # Evaluate tasks for each category
152
- optional_task_1_status = evaluate_tasks(fields, optional_task_1_subtasks)
153
- optional_task_2_status = evaluate_tasks(fields, optional_task_2_subtasks)
154
-
155
- # Check if tasks have any successful attempt
156
- opt1_done = any(status == "Attempted (Successful)" for status in optional_task_1_status.values())
157
- opt2_done = any(status == "Attempted (Successful)" for status in optional_task_2_status.values())
158
-
159
- return opt1_done, opt2_done
160
-
161
- # Read data from test_info.txt
162
- with open(test_info_location, "r") as file:
163
- data = file.readlines()
164
-
165
- # Assuming test_info[7] is a list with ideal tasks for each instance
166
- ideal_tasks = test_info[6] # A list where each element is either 1 or 2
167
-
168
- # Initialize counters
169
- task_counts = {
170
- 1: {"ER": 0, "ME": 0, "both": 0,"none":0},
171
- 2: {"ER": 0, "ME": 0, "both": 0,"none":0}
172
- }
173
-
174
- # Analyze rows
175
- for i, row in enumerate(data):
176
- row = row.strip()
177
- if not row:
178
- continue
179
-
180
- ideal_task = ideal_tasks[i] # Get the ideal task for the current row
181
- opt1_done, opt2_done = analyze_row(row)
182
-
183
- if ideal_task == 0:
184
- if opt1_done and not opt2_done:
185
- task_counts[1]["ER"] += 1
186
- elif not opt1_done and opt2_done:
187
- task_counts[1]["ME"] += 1
188
- elif opt1_done and opt2_done:
189
- task_counts[1]["both"] += 1
190
- else:
191
- task_counts[1]["none"] +=1
192
- elif ideal_task == 1:
193
- if opt1_done and not opt2_done:
194
- task_counts[2]["ER"] += 1
195
- elif not opt1_done and opt2_done:
196
- task_counts[2]["ME"] += 1
197
- elif opt1_done and opt2_done:
198
- task_counts[2]["both"] += 1
199
- else:
200
- task_counts[2]["none"] +=1
201
-
202
- # Create a string output for results
203
- # output_summary = "Task Analysis Summary:\n"
204
- # output_summary += "-----------------------\n"
205
-
206
- # for ideal_task, counts in task_counts.items():
207
- # output_summary += f"Ideal Task = OptionalTask_{ideal_task}:\n"
208
- # output_summary += f" Only OptionalTask_1 done: {counts['ER']}\n"
209
- # output_summary += f" Only OptionalTask_2 done: {counts['ME']}\n"
210
- # output_summary += f" Both done: {counts['both']}\n"
211
-
212
- # colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
213
- colors = ["#FF6F61", "#6B5B95", "#88B04B", "#F7CAC9"]
214
-
215
- # Generate pie chart for Task 1
216
- task1_labels = list(task_counts[1].keys())
217
- task1_values = list(task_counts[1].values())
218
-
219
- # fig_task1 = Figure()
220
- # ax1 = fig_task1.add_subplot(1, 1, 1)
221
- # ax1.pie(task1_values, labels=task1_labels, autopct='%1.1f%%', startangle=90)
222
- # ax1.set_title('Ideal Task 1 Distribution')
223
-
224
- fig_task1 = go.Figure(data=[go.Pie(
225
- labels=task1_labels,
226
- values=task1_values,
227
- textinfo='percent+label',
228
- textposition='auto',
229
- marker=dict(colors=colors),
230
- sort=False
231
-
232
- )])
233
-
234
- fig_task1.update_layout(
235
- title='Problem Type: ER',
236
- title_x=0.5,
237
- font=dict(
238
- family="sans-serif",
239
- size=12,
240
- color="black"
241
- ),
242
- )
243
-
244
- fig_task1.update_layout(
245
- legend=dict(
246
- font=dict(
247
- family="sans-serif",
248
- size=12,
249
- color="black"
250
- ),
251
- )
252
- )
253
-
254
-
255
-
256
- # fig.show()
257
-
258
- # Generate pie chart for Task 2
259
- task2_labels = list(task_counts[2].keys())
260
- task2_values = list(task_counts[2].values())
261
-
262
- fig_task2 = go.Figure(data=[go.Pie(
263
- labels=task2_labels,
264
- values=task2_values,
265
- textinfo='percent+label',
266
- textposition='auto',
267
- marker=dict(colors=colors),
268
- sort=False
269
- # pull=[0, 0.2, 0, 0] # for pulling part of pie chart out (depends on position)
270
-
271
- )])
272
-
273
- fig_task2.update_layout(
274
- title='Problem Type: ME',
275
- title_x=0.5,
276
- font=dict(
277
- family="sans-serif",
278
- size=12,
279
- color="black"
280
- ),
281
- )
282
-
283
- fig_task2.update_layout(
284
- legend=dict(
285
- font=dict(
286
- family="sans-serif",
287
- size=12,
288
- color="black"
289
- ),
290
- )
291
- )
292
-
293
-
294
- # fig_task2 = Figure()
295
- # ax2 = fig_task2.add_subplot(1, 1, 1)
296
- # ax2.pie(task2_values, labels=task2_labels, autopct='%1.1f%%', startangle=90)
297
- # ax2.set_title('Ideal Task 2 Distribution')
298
-
299
- # print(output_summary)
300
-
301
- progress(0.2, desc="analysis done!! Executing models")
302
- print("finetuned task: ",finetune_task)
303
  subprocess.run([
304
  "python", "new_test_saved_finetuned_model.py",
305
  "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
306
- "-finetune_task", finetune_task,
307
  "-test_dataset_path","../../../../selected_rows.txt",
308
  # "-test_label_path","../../../../train_label.txt",
309
  "-finetuned_bert_classifier_checkpoint",
@@ -321,510 +98,249 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
321
  result[key]=value
322
  else:
323
  result[key]=float(value)
324
- result["ROC score of HGR"]=high_roc_auc
325
- result["ROC score of LGR"]=low_roc_auc
326
  # Create a plot
327
  with open("roc_data.pkl", "rb") as f:
328
  fpr, tpr, _ = pickle.load(f)
329
- # print(fpr,tpr)
330
- roc_auc = auc(fpr, tpr)
331
-
332
-
333
- # Create a matplotlib figure
334
- # fig = Figure()
335
- # ax = fig.add_subplot(1, 1, 1)
336
- # ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
337
- # ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
338
- # ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'Receiver Operating Curve (ROC)')
339
- # ax.legend(loc="lower right")
340
- # ax.grid()
341
-
342
- fig = go.Figure()
343
- # Create and style traces
344
- fig.add_trace(go.Line(x = list(fpr), y = list(tpr), name=f'ROC curve (area = {roc_auc:.2f})',
345
- line=dict(color='royalblue', width=3,
346
- ) # dash options include 'dash', 'dot', and 'dashdot'
347
- ))
348
- fig.add_trace(go.Line(x = [0,1], y = [0,1], showlegend = False,
349
- line=dict(color='firebrick', width=2,
350
- dash='dash',) # dash options include 'dash', 'dot', and 'dashdot'
351
- ))
352
-
353
- # Edit the layout
354
- fig.update_layout(
355
- showlegend = True,
356
- title_x=0.5,
357
- title=dict(
358
- text='Receiver Operating Curve (ROC)'
359
- ),
360
- xaxis=dict(
361
- title=dict(
362
- text='False Positive Rate'
363
- )
364
- ),
365
- yaxis=dict(
366
- title=dict(
367
- text='False Negative Rate'
368
- )
369
- ),
370
- font=dict(
371
- family="sans-serif",
372
- color="black"
373
- ),
374
-
375
- )
376
- fig.update_layout(
377
- legend=dict(
378
- x=0.75,
379
- y=0,
380
- traceorder="normal",
381
- font=dict(
382
- family="sans-serif",
383
- size=12,
384
- color="black"
385
- ),
386
- )
387
- )
388
-
389
-
390
-
391
-
392
 
 
 
 
 
 
 
 
393
 
394
  # Save plot to a file
395
- # plot_path = "plot.png"
396
- # fig.savefig(plot_path)
397
- # plt.close(fig)
398
-
399
-
400
-
401
-
402
  progress(1.0)
403
  # Prepare text output
404
  text_output = f"Model: {model_name}\nResult:\n{result}"
405
  # Prepare text output with HTML formatting
406
  text_output = f"""
407
- ---------------------------
408
- Model: {model_name}
409
- ---------------------------\n
410
- Time Taken: {result['time_taken_from_start']:.2f} seconds
411
- Total Schools in test: {len(unique_schools):.4f}
412
- Total number of instances having Schools with HGR : {len(high_sample):.4f}
413
- Total number of instances having Schools with LGR: {len(low_sample):.4f}
414
-
415
- ROC score of HGR: {high_roc_auc:.4f}
416
- ROC score of LGR: {low_roc_auc:.4f}
417
-
418
- ROC-AUC for problems of type ER: {opt_task1_roc_auc:.4f}
419
- ROC-AUC for problems of type ME: {opt_task2_roc_auc:.4f}
420
  """
421
- return text_output,fig,fig_task1,fig_task2
422
 
423
  # List of models for the dropdown menu
424
 
425
- # models = ["ASTRA-FT-HGR", "ASTRA-FT-LGR", "ASTRA-FT-FULL"]
426
- models = ["ASTRA-FT-HGR", "ASTRA-FT-FULL"]
427
- content = """
428
- <h1 style="color: black;">A S T R A</h1>
429
- <h2 style="color: black;">An AI Model for Analyzing Math Strategies</h2>
430
-
431
- <h3 style="color: white; text-align: center">
432
- <a href="https://drive.google.com/file/d/1lbEpg8Se1ugTtkjreD8eXIg7qrplhWan/view" style="color: gr.themes.colors.red; text-decoration: none;">Link To Paper</a> |
433
- <a href="https://github.com/Syudu41/ASTRA---Gates-Project" style="color: #1E90FF; text-decoration: none;">GitHub</a> |
434
- <a href="https://sites.google.com/view/astra-research/home" style="color: #1E90FF; text-decoration: none;">Project Page</a>
435
- </h3>
436
 
437
- <p style="color: white;">Welcome to a demo of ASTRA. ASTRA is a collaborative research project between researchers at the
438
- <a href="https://sites.google.com/site/dvngopal/" style="color: #1E90FF; text-decoration: none;">University of Memphis</a> and
439
- <a href="https://www.carnegielearning.com" style="color: #1E90FF; text-decoration: none;">Carnegie Learning</a>
440
- to utilize AI to improve our understanding of math learning strategies.</p>
441
-
442
- <p style="color: white;">This demo has been developed with a pre-trained model (based on an architecture similar to BERT ) that learns math strategies using data
443
- collected from hundreds of schools in the U.S. who have used Carnegie Learning’s MATHia (formerly known as Cognitive Tutor), the flagship Intelligent Tutor that is part of a core, blended math curriculum.
444
- For this demo, we have used data from a specific domain (teaching ratio and proportions) within 7th grade math. The fine-tuning based on the pre-trained model learns to predict which strategies lead to correct vs incorrect solutions.
445
- </p>
446
-
447
- <p style="color: white;">In this math domain, students were given word problems related to ratio and proportions. Further, the students
448
- were given a choice of optional tasks to work on in parallel to the main problem to demonstrate their thinking (metacognition).
449
- The optional tasks are designed based on solving problems using Equivalent Ratios (ER) and solving using Means and Extremes/cross-multiplication (ME).
450
- When the equivalent ratios are easy to compute (integral values), ER is much more efficient compared to ME and switching between the tasks appropriately demonstrates cognitive flexibility.
451
- </p>
452
-
453
- <p style="color: white;">To use the demo, please follow these steps:</p>
454
-
455
- <ol style="color: white;">
456
- <li style="color: white;">Select a fine-tuned model:
457
- <ul style="color: white;">
458
- <li style="color: white;">ASTRA-FT-HGR: Fine-tuned with a small sample of data from schools that have a high graduation rate.</li>
459
- <li style="color: white;">ASTRA-FT-Full: Fine-tuned with a small sample of data from a mix of schools that have high/low graduation rates.</li>
460
- </ul>
461
- </li>
462
- <li style="color: white;">Select a percentage of schools to analyze (selecting a large percentage may take a long time). Note that the selected percentage is applied to both High Graduation Rate (HGR) schools and Low Graduation Rate (LGR schools).
463
- </li>
464
- <li style="color: white;">The results from the fine-tuned model are displayed in the dashboard:
465
- <ul>
466
- <li style="color: white;">The model accuracy is computed using the ROC-AUC metric.
467
- </li>
468
- <li style="color: white;">The results are shown for HGR, LGR schools and for different problem types (ER/ME).
469
- </li>
470
- <li style="color: white;">The distribution over how students utilized the optional tasks (whether they utilized ER/ME, used both of them or none of them) is shown for each problem type.
471
- </li>
472
- </ul>
473
- </li>
474
- </ol>
475
- """
476
- # CSS styling for white text
477
  # Create the Gradio interface
478
- available_themes = {
479
- "default": gr.themes.Default(),
480
- "soft": gr.themes.Soft(),
481
- "monochrome": gr.themes.Monochrome(),
482
- "glass": gr.themes.Glass(),
483
- "base": gr.themes.Base(),
484
- }
485
-
486
- # Comprehensive CSS for all HTML elements
487
- custom_css = '''
488
- /* Import Fira Sans font */
489
- @import url('https://fonts.googleapis.com/css2?family=Fira+Sans:wght@400;500;600;700&family=Inter:wght@400;500;600;700&display=swap');
490
- @import url('https://fonts.googleapis.com/css2?family=Libre+Caslon+Text:ital,wght@0,400;0,700;1,400&family=Spectral+SC:wght@600&display=swap');
491
- /* Container modifications for centering */
492
- .gradio-container {
493
- color: var(--block-label-text-color) !important;
494
- max-width: 1000px !important;
495
- margin: 0 auto !important;
496
- padding: 2rem !important;
497
- font-family: Arial, sans-serif !important;
498
- }
499
-
500
- /* Main title (ASTRA) */
501
- #title {
502
- text-align: center !important;
503
- margin: 1rem auto !important; /* Reduced margin */
504
- font-size: 2.5em !important;
505
- font-weight: 600 !important;
506
- font-family: "Spectral SC", 'Fira Sans', sans-serif !important;
507
- padding-bottom: 0 !important; /* Remove bottom padding */
508
- }
509
-
510
- /* Subtitle (An AI Model...) */
511
- h1 {
512
- text-align: center !important;
513
- font-size: 30pt !important;
514
- font-weight: 600 !important;
515
- font-family: "Spectral SC", 'Fira Sans', sans-serif !important;
516
- margin-top: 0.5em !important; /* Reduced top margin */
517
- margin-bottom: 0.3em !important;
518
- }
519
-
520
- h2 {
521
- text-align: center !important;
522
- font-size: 22pt !important;
523
- font-weight: 600 !important;
524
- font-family: "Spectral SC",'Fira Sans', sans-serif !important;
525
- margin-top: 0.2em !important; /* Reduced top margin */
526
- margin-bottom: 0.3em !important;
527
- }
528
-
529
- /* Links container styling */
530
- .links-container {
531
- text-align: center !important;
532
- margin: 1em auto !important;
533
- font-family: 'Inter' ,'Fira Sans', sans-serif !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
  }
535
-
536
- /* Links */
537
- a {
538
- color: #2563eb !important;
539
- text-decoration: none !important;
540
- font-family:'Inter' , 'Fira Sans', sans-serif !important;
541
  }
542
 
543
- a:hover {
544
- text-decoration: underline !important;
545
- opacity: 0.8;
 
 
 
 
 
 
 
546
  }
547
 
548
- /* Regular text */
549
- p, li, .description, .markdown-text {
550
- font-family: 'Inter', Arial, sans-serif !important;
551
- color: black !important;
552
- font-size: 11pt;
553
- line-height: 1.6;
554
- font-weight: 500 !important;
555
- color: var(--block-label-text-color) !important;
 
 
556
  }
557
 
558
- /* Other headings */
559
- h3, h4, h5 {
560
- font-family: 'Fira Sans', sans-serif !important;
561
- color: var(--block-label-text-color) !important;
562
- margin-top: 1.5em;
563
- margin-bottom: 0.75em;
564
  }
565
-
566
-
567
- h3 { font-size: 1.5em; font-weight: 600; }
568
- h4 { font-size: 1.25em; font-weight: 500; }
569
- h5 { font-size: 1.1em; font-weight: 500; }
570
-
571
- /* Form elements */
572
- .select-wrap select, .wrap select,
573
- input, textarea {
574
- font-family: 'Inter' ,Arial, sans-serif !important;
575
- color: var(--block-label-text-color) !important;
576
  }
577
-
578
- /* Lists */
579
- ul, ol {
580
- margin-left: 0 !important;
581
- margin-bottom: 1.25em;
582
- padding-left: 2em;
583
  }
584
 
585
- li {
586
- margin-bottom: 0.75em;
587
- }
588
-
589
- /* Form container */
590
- .form-container {
591
- max-width: 1000px !important;
592
- margin: 0 auto !important;
593
- padding: 1rem !important;
594
- }
595
-
596
- /* Dashboard */
597
- .dashboard {
598
- margin-top: 2rem !important;
599
- padding: 1rem !important;
600
- border-radius: 8px !important;
601
  }
602
-
603
- /* Slider styling */
604
- .gradio-slider-row {
605
  display: flex;
 
 
606
  align-items: center;
607
- justify-content: space-between;
608
- margin: 1.5em 0;
609
- max-width: 100% !important;
610
- }
611
-
612
- .gradio-slider {
613
- flex-grow: 1;
614
- margin-right: 15px;
615
- }
616
-
617
- .slider-percentage {
618
- font-family: 'Inter', Arial, sans-serif !important;
619
- flex-shrink: 0;
620
- min-width: 60px;
621
- font-size: 1em;
622
- font-weight: bold;
623
  text-align: center;
624
- background-color: #f0f8ff;
625
- border: 1px solid #004080;
626
- border-radius: 5px;
627
- padding: 5px 10px;
628
- }
629
-
630
- .progress-bar-wrap.progress-bar-wrap.progress-bar-wrap
631
- {
632
- border-radius: var(--input-radius);
633
- height: 1.25rem;
634
- margin-top: 1rem;
635
- overflow: hidden;
636
- width: 70%;
637
- font-family: 'Inter', Arial, sans-serif !important;
638
- }
639
-
640
- /* Add these new styles after your existing CSS */
641
-
642
- /* Card-like appearance for the dashboard */
643
- .dashboard {
644
- background: #ffffff !important;
645
- box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
646
- border-radius: 12px !important;
647
- padding: 2rem !important;
648
- margin-top: 2.5rem !important;
649
  }
650
-
651
- /* Enhance ROC graph container */
652
- #roc {
653
- background: #ffffff !important;
654
- padding: 1.5rem !important;
655
- border-radius: 8px !important;
656
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
657
- margin: 1.5rem 0 !important;
658
  }
659
-
660
- /* Style the dropdown select */
661
- select {
662
- background-color: #ffffff !important;
663
- border: 1px solid #e2e8f0 !important;
664
- border-radius: 8px !important;
665
- padding: 0.5rem 1rem !important;
666
- transition: all 0.2s ease-in-out !important;
667
- box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05) !important;
668
- }
669
-
670
- select:hover {
671
- border-color: #cbd5e1 !important;
672
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
673
  }
674
-
675
- /* Enhance slider appearance */
676
- .progress-bar-wrap {
677
- background: #f8fafc !important;
678
- border: 1px solid #e2e8f0 !important;
679
- box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.05) !important;
680
- }
681
-
682
- /* Style metrics in dashboard */
683
- .dashboard p {
684
- padding: 0.5rem 0 !important;
685
- border-bottom: 1px solid #f1f5f9 !important;
686
- }
687
-
688
- /* Add spacing between sections */
689
- .dashboard > div {
690
- margin-bottom: 1.5rem !important;
691
- }
692
-
693
- /* Style the ROC curve title */
694
- .dashboard h4 {
695
- color: #1e293b !important;
696
- font-weight: 600 !important;
697
- margin-bottom: 1rem !important;
698
- padding-bottom: 0.5rem !important;
699
- border-bottom: 2px solid #e2e8f0 !important;
700
- }
701
-
702
- /* Enhance link appearances */
703
- a {
704
- position: relative !important;
705
- padding-bottom: 2px !important;
706
- transition: all 0.2s ease-in-out !important;
707
- }
708
-
709
- a:after {
710
- content: '' !important;
711
- position: absolute !important;
712
- width: 0 !important;
713
- height: 1px !important;
714
- bottom: 0 !important;
715
- left: 0 !important;
716
- background-color: #2563eb !important;
717
- transition: width 0.3s ease-in-out !important;
718
- }
719
-
720
- a:hover:after {
721
- width: 100% !important;
722
- }
723
-
724
- /* Add subtle dividers between sections */
725
- .form-container > div {
726
- padding-bottom: 1.5rem !important;
727
- margin-bottom: 1.5rem !important;
728
- border-bottom: 1px solid #f1f5f9 !important;
729
- }
730
-
731
- /* Style model selection section */
732
- .select-wrap {
733
- background: #ffffff !important;
734
- padding: 1.5rem !important;
735
- border-radius: 8px !important;
736
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
737
- margin-bottom: 2rem !important;
738
- }
739
-
740
- /* Style the metrics display */
741
- .dashboard span {
742
- font-family: 'Inter', sans-serif !important;
743
- font-weight: 500 !important;
744
- color: #334155 !important;
745
- }
746
-
747
- /* Add subtle animation to interactive elements */
748
- button, select, .slider-percentage {
749
- transition: all 0.2s ease-in-out !important;
750
- }
751
-
752
- /* Style the ROC curve container */
753
- .plot-container {
754
- background: #ffffff !important;
755
- border-radius: 8px !important;
756
- padding: 1rem !important;
757
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
758
- }
759
-
760
- /* Add container styles for opt1 and opt2 sections */
761
- #opt1, #opt2 {
762
- background: #ffffff !important;
763
- border-radius: 8px !important;
764
- padding: 1.5rem !important;
765
- margin-top: 1.5rem !important;
766
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05) !important;
767
- }
768
-
769
- /* Style the distribution titles */
770
- .distribution-title {
771
- font-family: 'Inter', sans-serif !important;
772
- font-weight: 600 !important;
773
- color: #1e293b !important;
774
- margin-bottom: 1rem !important;
775
- text-align: center !important;
776
- }
777
-
778
- '''
779
-
780
- with gr.Blocks(theme='gstaff/sketch', css=custom_css) as demo:
781
-
782
- # gr.Markdown("<h1 id='title'>ASTRA</h1>", elem_id="title")
783
- gr.Markdown(content)
784
 
785
  with gr.Row():
786
- # file_input = gr.File(label="Upload a test file", file_types=['.txt'], elem_classes="file-box")
787
- # label_input = gr.File(label="Upload test labels", file_types=['.txt'], elem_classes="file-box")
788
 
789
- # info_input = gr.File(label="Upload test info", file_types=['.txt'], elem_classes="file-box")
790
- model_dropdown = gr.Dropdown(
791
- choices=models,
792
- label="Select Fine-tuned Model",
793
- elem_classes="dropdown-menu"
794
- )
795
- increment_slider = gr.Slider(
796
- minimum=1,
797
- maximum=100,
798
- step=1,
799
- label="Schools Percentage",
800
- value=1,
801
- elem_id="increment-slider",
802
- elem_classes="gradio-slider"
803
- )
804
 
805
- with gr.Row():
806
- btn = gr.Button("Submit")
807
-
808
- gr.Markdown("<p class='description'>Dashboard</p>")
809
-
810
- with gr.Row():
811
- output_text = gr.Textbox(label="")
812
- # output_image = gr.Image(label="ROC")
813
- with gr.Row():
814
- plot_output = gr.Plot(label="ROC")
815
 
 
 
 
816
  with gr.Row():
817
- opt1_pie = gr.Plot(label="ER")
818
- opt2_pie = gr.Plot(label="ME")
819
- # output_summary = gr.Textbox(label="Summary")
820
 
821
-
822
 
823
- btn.click(
824
- fn=process_file,
825
- inputs=[model_dropdown,increment_slider],
826
- outputs=[output_text,plot_output,opt1_pie,opt2_pie]
827
- )
828
 
829
 
830
  # Launch the app
 
8
  import matplotlib.pyplot as plt
9
  from sklearn.metrics import roc_curve, auc
10
  import pandas as pd
 
 
 
11
  # Define the function to process the input file and model selection
12
 
13
+ def process_file(file,label,info,model_name,inc_slider,progress=Progress(track_tqdm=True)):
14
  # progress = gr.Progress(track_tqdm=True)
 
15
  progress(0, desc="Starting the processing")
16
+ with open(file.name, 'r') as f:
17
+ content = f.read()
18
+ saved_test_dataset = "train.txt"
19
+ saved_test_label = "train_label.txt"
20
+ saved_train_info="train_info.txt"
21
  # Save the uploaded file content to a specified location
22
+ shutil.copyfile(file.name, saved_test_dataset)
23
+ shutil.copyfile(label.name, saved_test_label)
24
+ shutil.copyfile(info.name, saved_train_info)
25
+
26
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Load the test_info file and the graduation rate file
28
+ test_info = pd.read_csv('train_info.txt', sep=',', header=None, engine='python')
29
  grad_rate_data = pd.DataFrame(pd.read_pickle('school_grduation_rate.pkl'),columns=['school_number','grad_rate']) # Load the grad_rate data
30
 
31
  # Step 1: Extract unique school numbers from test_info
 
50
 
51
  # Step 7: Get indices for the sampled schools
52
  indices = test_info[test_info[0].isin(random_schools)].index.tolist()
53
+
 
 
54
  # Load the test file and select rows based on indices
55
+ test = pd.read_csv('train.txt', sep=',', header=None, engine='python')
56
  selected_rows_df2 = test.loc[indices]
57
 
58
  # Save the selected rows to a file
59
  selected_rows_df2.to_csv('selected_rows.txt', sep='\t', index=False, header=False, quoting=3, escapechar=' ')
60
 
61
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  # For demonstration purposes, we'll just return the content with the selected model name
63
+ if(model_name=="High Graduated Schools"):
64
+ finetune_task="highGRschool10"
65
+ elif(model_name== "Low Graduated Schools" ):
66
+ finetune_task="highGRschool10"
67
+ elif(model_name=="Full Set"):
68
+ finetune_task="highGRschool10"
69
+ else:
70
+ finetune_task=None
71
  # print(checkpoint)
72
  progress(0.1, desc="Files created and saved")
73
  # if (inc_val<5):
 
76
  # model_name="highGRschool10"
77
  # else:
78
  # model_name="highGRschool10"
79
+ progress(0.2, desc="Executing models")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  subprocess.run([
81
  "python", "new_test_saved_finetuned_model.py",
82
  "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
83
+ "-finetune_task", "highGRschool10",
84
  "-test_dataset_path","../../../../selected_rows.txt",
85
  # "-test_label_path","../../../../train_label.txt",
86
  "-finetuned_bert_classifier_checkpoint",
 
98
  result[key]=value
99
  else:
100
  result[key]=float(value)
 
 
101
  # Create a plot
102
  with open("roc_data.pkl", "rb") as f:
103
  fpr, tpr, _ = pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ roc_auc = auc(fpr, tpr)
106
+ fig, ax = plt.subplots()
107
+ ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
108
+ ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
109
+ ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'ROC Curve: {model_name}')
110
+ ax.legend(loc="lower right")
111
+ ax.grid()
112
 
113
  # Save plot to a file
114
+ plot_path = "plot.png"
115
+ fig.savefig(plot_path)
116
+ plt.close(fig)
 
 
 
 
117
  progress(1.0)
118
  # Prepare text output
119
  text_output = f"Model: {model_name}\nResult:\n{result}"
120
  # Prepare text output with HTML formatting
121
  text_output = f"""
122
+ Model: {model_name}\n
123
+ Result Summary:\n
124
+ -----------------\n
125
+ Precision: {result['precisions']:.2f}\n
126
+ Recall: {result['recalls']:.2f}\n
127
+ Time Taken: {result['time_taken_from_start']:.2f} seconds\n
128
+ Total Schools in test: {len(unique_schools):.4f}\n
129
+ Total Schools taken: {len(random_schools):.4f}\n
130
+ High grad schools: {len(high_sample):.4f}\n
131
+ Low grad schools: {len(low_sample):.4f}\n
132
+ -----------------\n
133
+ Note: The ROC Curve is also displayed for the evaluation.
 
134
  """
135
+ return text_output,plot_path
136
 
137
  # List of models for the dropdown menu
138
 
139
+ models = ["High Graduated Schools", "Low Graduated Schools", "Full Set"]
 
 
 
 
 
 
 
 
 
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  # Create the Gradio interface
142
+ with gr.Blocks(css="""
143
+ body {
144
+ background-color: #1e1e1e!important;
145
+ font-family: 'Arial', sans-serif;
146
+ color: #f5f5f5!important;;
147
+ }
148
+ .gradio-container {
149
+ max-width: 850px!important;
150
+ margin: 0 auto!important;;
151
+ padding: 20px!important;;
152
+ background-color: #292929!important;
153
+ border-radius: 10px;
154
+ box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2);
155
+ }
156
+ .gradio-container-4-44-0 .prose h1 {
157
+ font-size: var(--text-xxl);
158
+ color: #ffffff!important;
159
+ }
160
+ #title {
161
+ color: white!important;
162
+ font-size: 2.3em;
163
+ font-weight: bold;
164
+ text-align: center!important;
165
+ margin-bottom: 20px;
166
+ }
167
+ .description {
168
+ text-align: center;
169
+ font-size: 1.1em;
170
+ color: #bfbfbf;
171
+ margin-bottom: 30px;
172
+ }
173
+ .file-box {
174
+ max-width: 180px;
175
+ padding: 5px;
176
+ background-color: #444!important;
177
+ border: 1px solid #666!important;
178
+ border-radius: 6px;
179
+ height: 80px!important;;
180
+ margin: 0 auto!important;;
181
+ text-align: center;
182
+ color: transparent;
183
+ }
184
+ .file-box span {
185
+ color: #f5f5f5!important;
186
+ font-size: 1em;
187
+ line-height: 45px; /* Vertically center text */
188
+ }
189
+ .dropdown-menu {
190
+ max-width: 220px;
191
+ margin: 0 auto!important;
192
+ background-color: #444!important;
193
+ color:#444!important;
194
+ border-radius: 6px;
195
+ padding: 8px;
196
+ font-size: 1.1em;
197
+ border: 1px solid #666;
198
+ }
199
+ .button {
200
+ background-color: #4CAF50!important;
201
+ color: white!important;
202
+ font-size: 1.1em;
203
+ padding: 10px 25px;
204
+ border-radius: 6px;
205
+ cursor: pointer;
206
+ transition: background-color 0.2s ease-in-out;
207
+ }
208
+ .button:hover {
209
+ background-color: #45a049!important;
210
+ }
211
+ .output-text {
212
+ background-color: #333!important;
213
+ padding: 12px;
214
+ border-radius: 8px;
215
+ border: 1px solid #666;
216
+ font-size: 1.1em;
217
+ }
218
+ .footer {
219
+ text-align: center;
220
+ margin-top: 50px;
221
+ font-size: 0.9em;
222
+ color: #b0b0b0;
223
+ }
224
+ .svelte-12ioyct .wrap {
225
+ display: none !important;
226
  }
227
+ .file-label-text {
228
+ display: none !important;
 
 
 
 
229
  }
230
 
231
+ div.svelte-sfqy0y {
232
+ display: flex;
233
+ flex-direction: inherit;
234
+ flex-wrap: wrap;
235
+ gap: var(--form-gap-width);
236
+ box-shadow: var(--block-shadow);
237
+ border: var(--block-border-width) solid var(--border-color-primary);
238
+ border-radius: var(--block-radius);
239
+ background: #1f2937!important;
240
+ overflow-y: hidden;
241
  }
242
 
243
+ .block.svelte-12cmxck {
244
+ position: relative;
245
+ margin: 0;
246
+ box-shadow: var(--block-shadow);
247
+ border-width: var(--block-border-width);
248
+ border-color: var(--block-border-color);
249
+ border-radius: var(--block-radius);
250
+ background: #1f2937!important;
251
+ width: 100%;
252
+ line-height: var(--line-sm);
253
  }
254
 
255
+ .svelte-12ioyct .wrap {
256
+ display: none !important;
 
 
 
 
257
  }
258
+ .file-label-text {
259
+ display: none !important;
 
 
 
 
 
 
 
 
 
260
  }
261
+ input[aria-label="file upload"] {
262
+ display: none !important;
 
 
 
 
263
  }
264
 
265
+ gradio-app .gradio-container.gradio-container-4-44-0 .contain .file-box span {
266
+ font-size: 1em;
267
+ line-height: 45px;
268
+ color: #1f2937 !important;
 
 
 
 
 
 
 
 
 
 
 
 
269
  }
270
+ .wrap.svelte-12ioyct {
 
 
271
  display: flex;
272
+ flex-direction: column;
273
+ justify-content: center;
274
  align-items: center;
275
+ min-height: var(--size-60);
276
+ color: #1f2937 !important;
277
+ line-height: var(--line-md);
278
+ height: 100%;
279
+ padding-top: var(--size-3);
 
 
 
 
 
 
 
 
 
 
 
280
  text-align: center;
281
+ margin: auto var(--spacing-lg);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  }
283
+ span.svelte-1gfkn6j:not(.has-info) {
284
+ margin-bottom: var(--spacing-lg);
285
+ color: white!important;
 
 
 
 
 
286
  }
287
+ label.float.svelte-1b6s6s {
288
+ position: relative!important;
289
+ top: var(--block-label-margin);
290
+ left: var(--block-label-margin);
 
 
 
 
 
 
 
 
 
 
291
  }
292
+ label.svelte-1b6s6s {
293
+ display: inline-flex;
294
+ align-items: center;
295
+ z-index: var(--layer-2);
296
+ box-shadow: var(--block-label-shadow);
297
+ border: var(--block-label-border-width) solid var(--border-color-primary);
298
+ border-top: none;
299
+ border-left: none;
300
+ border-radius: var(--block-label-radius);
301
+ background: rgb(120 151 180)!important;
302
+ padding: var(--block-label-padding);
303
+ pointer-events: none;
304
+ color: #1f2937!important;
305
+ font-weight: var(--block-label-text-weight);
306
+ font-size: var(--block-label-text-size);
307
+ line-height: var(--line-sm);
308
+ }
309
+ .file.svelte-18wv37q.svelte-18wv37q {
310
+ display: block!important;
311
+ width: var(--size-full);
312
+ }
313
+
314
+ tbody.svelte-18wv37q>tr.svelte-18wv37q:nth-child(odd) {
315
+ background: ##7897b4!important;
316
+ color: white;
317
+ background: #aca7b2;
318
+ }
319
+ .gradio-container-4-31-4 .prose h1, .gradio-container-4-31-4 .prose h2, .gradio-container-4-31-4 .prose h3, .gradio-container-4-31-4 .prose h4, .gradio-container-4-31-4 .prose h5 {
320
+
321
+ color: white;
322
+ """) as demo:
323
+ gr.Markdown("<h1 id='title'>ASTRA</h1>", elem_id="title")
324
+ gr.Markdown("<p class='description'>Upload a .txt file and select a model from the dropdown menu.</p>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
  with gr.Row():
327
+ file_input = gr.File(label="Upload a test file", file_types=['.txt'], elem_classes="file-box")
328
+ label_input = gr.File(label="Upload test labels", file_types=['.txt'], elem_classes="file-box")
329
 
330
+ info_input = gr.File(label="Upload test info", file_types=['.txt'], elem_classes="file-box")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
+ model_dropdown = gr.Dropdown(choices=models, label="Select Finetune Task", elem_classes="dropdown-menu")
 
 
 
 
 
 
 
 
 
333
 
334
+
335
+ increment_slider = gr.Slider(minimum=1, maximum=100, step=1, label="Schools Percentage", value=1)
336
+
337
  with gr.Row():
338
+ output_text = gr.Textbox(label="Output Text")
339
+ output_image = gr.Image(label="Output Plot")
 
340
 
341
+ btn = gr.Button("Submit")
342
 
343
+ btn.click(fn=process_file, inputs=[file_input,label_input,info_input,model_dropdown,increment_slider], outputs=[output_text,output_image])
 
 
 
 
344
 
345
 
346
  # Launch the app
distinguish_high_low_label.ipynb DELETED
@@ -1,553 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 27,
6
- "id": "960bac80-51c7-4e9f-ad2d-84cd6c710f98",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "import pickle\n",
11
- "import pandas as pd\n",
12
- "from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_curve, roc_auc_score,auc"
13
- ]
14
- },
15
- {
16
- "cell_type": "code",
17
- "execution_count": 3,
18
- "id": "a34f21d0-0854-4a54-8f93-67718b2f969e",
19
- "metadata": {},
20
- "outputs": [],
21
- "source": [
22
- "file_path = \"roc_data2.pkl\"\n",
23
- "\n",
24
- "# Open and load the pickle file\n",
25
- "with open(file_path, 'rb') as file:\n",
26
- " data = pickle.load(file)\n",
27
- "\n",
28
- "\n",
29
- "# Print or use the data\n",
30
- "# data[2]"
31
- ]
32
- },
33
- {
34
- "cell_type": "code",
35
- "execution_count": 4,
36
- "id": "f9febed4-ce50-4e30-96ea-4b538ce2f9a1",
37
- "metadata": {},
38
- "outputs": [],
39
- "source": [
40
- "inc_slider=1\n",
41
- "parent_location=\"ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/\"\n",
42
- "test_info_location=parent_location+\"fullTest/test_info.txt\"\n",
43
- "test_location=parent_location+\"fullTest/test.txt\"\n",
44
- "test_info = pd.read_csv(test_info_location, sep=',', header=None, engine='python')\n",
45
- "grad_rate_data = pd.DataFrame(pd.read_pickle('school_grduation_rate.pkl'),columns=['school_number','grad_rate']) # Load the grad_rate data\n",
46
- "\n",
47
- "# Step 1: Extract unique school numbers from test_info\n",
48
- "unique_schools = test_info[0].unique()\n",
49
- "\n",
50
- "# Step 2: Filter the grad_rate_data using the unique school numbers\n",
51
- "schools = grad_rate_data[grad_rate_data['school_number'].isin(unique_schools)]\n",
52
- "\n",
53
- "# Define a threshold for high and low graduation rates (adjust as needed)\n",
54
- "grad_rate_threshold = 0.9 \n",
55
- "\n",
56
- "# Step 4: Divide schools into high and low graduation rate groups\n",
57
- "high_grad_schools = schools[schools['grad_rate'] >= grad_rate_threshold]['school_number'].unique()\n",
58
- "low_grad_schools = schools[schools['grad_rate'] < grad_rate_threshold]['school_number'].unique()\n",
59
- "\n",
60
- "# Step 5: Sample percentage of schools from each group\n",
61
- "high_sample = pd.Series(high_grad_schools).sample(frac=inc_slider/100, random_state=1).tolist()\n",
62
- "low_sample = pd.Series(low_grad_schools).sample(frac=inc_slider/100, random_state=1).tolist()\n",
63
- "\n",
64
- "# Step 6: Combine the sampled schools\n",
65
- "random_schools = high_sample + low_sample\n",
66
- "\n",
67
- "# Step 7: Get indices for the sampled schools\n",
68
- "indices = test_info[test_info[0].isin(random_schools)].index.tolist()\n",
69
- "\n"
70
- ]
71
- },
72
- {
73
- "cell_type": "code",
74
- "execution_count": 5,
75
- "id": "fdfdf4b6-2752-4a21-9880-869af69f20cf",
76
- "metadata": {},
77
- "outputs": [],
78
- "source": [
79
- "high_indices = test_info[(test_info[0].isin(high_sample))].index.tolist()\n",
80
- "low_indices = test_info[(test_info[0].isin(low_sample))].index.tolist()"
81
- ]
82
- },
83
- {
84
- "cell_type": "code",
85
- "execution_count": 6,
86
- "id": "a79a4598-5702-4cc8-9f07-8e18fdda648b",
87
- "metadata": {},
88
- "outputs": [
89
- {
90
- "data": {
91
- "text/plain": [
92
- "997"
93
- ]
94
- },
95
- "execution_count": 6,
96
- "metadata": {},
97
- "output_type": "execute_result"
98
- }
99
- ],
100
- "source": [
101
- "len(high_indices)+len(low_indices)\n"
102
- ]
103
- },
104
- {
105
- "cell_type": "code",
106
- "execution_count": 7,
107
- "id": "4707f3e6-2f44-46d8-ad8c-b6c244f693af",
108
- "metadata": {},
109
- "outputs": [
110
- {
111
- "data": {
112
- "text/html": [
113
- "<div>\n",
114
- "<style scoped>\n",
115
- " .dataframe tbody tr th:only-of-type {\n",
116
- " vertical-align: middle;\n",
117
- " }\n",
118
- "\n",
119
- " .dataframe tbody tr th {\n",
120
- " vertical-align: top;\n",
121
- " }\n",
122
- "\n",
123
- " .dataframe thead th {\n",
124
- " text-align: right;\n",
125
- " }\n",
126
- "</style>\n",
127
- "<table border=\"1\" class=\"dataframe\">\n",
128
- " <thead>\n",
129
- " <tr style=\"text-align: right;\">\n",
130
- " <th></th>\n",
131
- " <th>0</th>\n",
132
- " </tr>\n",
133
- " </thead>\n",
134
- " <tbody>\n",
135
- " <tr>\n",
136
- " <th>5342</th>\n",
137
- " <td>PercentChange-0\\tNumeratorQuantity1-0\\tNumerat...</td>\n",
138
- " </tr>\n",
139
- " <tr>\n",
140
- " <th>5343</th>\n",
141
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
142
- " </tr>\n",
143
- " <tr>\n",
144
- " <th>5344</th>\n",
145
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
146
- " </tr>\n",
147
- " <tr>\n",
148
- " <th>5345</th>\n",
149
- " <td>PercentChange-0\\tNumeratorQuantity2-2\\tNumerat...</td>\n",
150
- " </tr>\n",
151
- " <tr>\n",
152
- " <th>5346</th>\n",
153
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tDenomin...</td>\n",
154
- " </tr>\n",
155
- " <tr>\n",
156
- " <th>...</th>\n",
157
- " <td>...</td>\n",
158
- " </tr>\n",
159
- " <tr>\n",
160
- " <th>113359</th>\n",
161
- " <td>PercentChange-0\\tNumeratorQuantity2-2\\tNumerat...</td>\n",
162
- " </tr>\n",
163
- " <tr>\n",
164
- " <th>113360</th>\n",
165
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
166
- " </tr>\n",
167
- " <tr>\n",
168
- " <th>113361</th>\n",
169
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
170
- " </tr>\n",
171
- " <tr>\n",
172
- " <th>113362</th>\n",
173
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
174
- " </tr>\n",
175
- " <tr>\n",
176
- " <th>113363</th>\n",
177
- " <td>PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...</td>\n",
178
- " </tr>\n",
179
- " </tbody>\n",
180
- "</table>\n",
181
- "<p>997 rows × 1 columns</p>\n",
182
- "</div>"
183
- ],
184
- "text/plain": [
185
- " 0\n",
186
- "5342 PercentChange-0\\tNumeratorQuantity1-0\\tNumerat...\n",
187
- "5343 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
188
- "5344 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
189
- "5345 PercentChange-0\\tNumeratorQuantity2-2\\tNumerat...\n",
190
- "5346 PercentChange-0\\tNumeratorQuantity2-0\\tDenomin...\n",
191
- "... ...\n",
192
- "113359 PercentChange-0\\tNumeratorQuantity2-2\\tNumerat...\n",
193
- "113360 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
194
- "113361 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
195
- "113362 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
196
- "113363 PercentChange-0\\tNumeratorQuantity2-0\\tNumerat...\n",
197
- "\n",
198
- "[997 rows x 1 columns]"
199
- ]
200
- },
201
- "execution_count": 7,
202
- "metadata": {},
203
- "output_type": "execute_result"
204
- }
205
- ],
206
- "source": [
207
- "# Load the test file and select rows based on indices\n",
208
- "test = pd.read_csv(test_location, sep=',', header=None, engine='python')\n",
209
- "selected_rows_df2 = test.loc[indices]\n",
210
- "selected_rows_df2"
211
- ]
212
- },
213
- {
214
- "cell_type": "code",
215
- "execution_count": 8,
216
- "id": "1d0c3d49-061f-486b-9c19-cf20945f3207",
217
- "metadata": {},
218
- "outputs": [
219
- {
220
- "data": {
221
- "text/plain": [
222
- "997"
223
- ]
224
- },
225
- "execution_count": 8,
226
- "metadata": {},
227
- "output_type": "execute_result"
228
- }
229
- ],
230
- "source": [
231
- "graduation_groups = [\n",
232
- " 'high' if idx in high_indices else 'low' for idx in selected_rows_df2.index\n",
233
- "]\n",
234
- "# graduation_groups\n",
235
- "len(graduation_groups)"
236
- ]
237
- },
238
- {
239
- "cell_type": "code",
240
- "execution_count": 9,
241
- "id": "d2508a0f-e5ca-432e-b99b-481ea4536d4d",
242
- "metadata": {},
243
- "outputs": [
244
- {
245
- "data": {
246
- "text/plain": [
247
- "997"
248
- ]
249
- },
250
- "execution_count": 9,
251
- "metadata": {},
252
- "output_type": "execute_result"
253
- }
254
- ],
255
- "source": [
256
- "opt_task_groups = ['opt_task1' if test_info.loc[idx, 6] == 0 else 'opt_task2' for idx in selected_rows_df2.index]\n",
257
- "len(opt_task_groups)"
258
- ]
259
- },
260
- {
261
- "cell_type": "code",
262
- "execution_count": 10,
263
- "id": "ad0ce4a1-27fa-4867-8061-4054dbb340df",
264
- "metadata": {},
265
- "outputs": [],
266
- "source": [
267
- "t_label=data[0]\n",
268
- "p_label=data[1]"
269
- ]
270
- },
271
- {
272
- "cell_type": "code",
273
- "execution_count": 12,
274
- "id": "a4f4a2b9-3134-42ac-871b-4e117098cd0e",
275
- "metadata": {},
276
- "outputs": [],
277
- "source": [
278
- "# Step 1: Align graduation_group, t_label, and p_label\n",
279
- "aligned_labels = list(zip(graduation_groups, t_label, p_label))\n",
280
- "opt_task_aligned = list(zip(opt_task_groups, t_label, p_label))\n",
281
- "# Step 2: Separate the labels for high and low groups\n",
282
- "high_t_labels = [t for grad, t, p in aligned_labels if grad == 'high']\n",
283
- "low_t_labels = [t for grad, t, p in aligned_labels if grad == 'low']\n",
284
- "\n",
285
- "high_p_labels = [p for grad, t, p in aligned_labels if grad == 'high']\n",
286
- "low_p_labels = [p for grad, t, p in aligned_labels if grad == 'low']\n",
287
- "\n",
288
- "\n",
289
- "opt_task1_t_labels = [t for task, t, p in opt_task_aligned if task == 'opt_task1']\n",
290
- "opt_task1_p_labels = [p for task, t, p in opt_task_aligned if task == 'opt_task1']\n",
291
- "\n",
292
- "opt_task2_t_labels = [t for task, t, p in opt_task_aligned if task == 'opt_task2']\n",
293
- "opt_task2_p_labels = [p for task, t, p in opt_task_aligned if task == 'opt_task2']\n"
294
- ]
295
- },
296
- {
297
- "cell_type": "code",
298
- "execution_count": 15,
299
- "id": "74cda932-ce98-4ad5-9c29-a54bdc4ee086",
300
- "metadata": {},
301
- "outputs": [
302
- {
303
- "name": "stdout",
304
- "output_type": "stream",
305
- "text": [
306
- "opt_task1 ROC-AUC: 0.7592686234399062\n",
307
- "opt_task2 ROC-AUC: 0.7268598353289777\n"
308
- ]
309
- }
310
- ],
311
- "source": [
312
- "\n",
313
- "opt_task1_roc_auc = roc_auc_score(opt_task1_t_labels, opt_task1_p_labels) if len(set(opt_task1_t_labels)) > 1 else None\n",
314
- "opt_task2_roc_auc = roc_auc_score(opt_task2_t_labels, opt_task2_p_labels) if len(set(opt_task2_t_labels)) > 1 else None\n",
315
- "\n",
316
- "print(f\"opt_task1 ROC-AUC: {opt_task1_roc_auc}\")\n",
317
- "print(f\"opt_task2 ROC-AUC: {opt_task2_roc_auc}\")"
318
- ]
319
- },
320
- {
321
- "cell_type": "code",
322
- "execution_count": 50,
323
- "id": "c8e34660-83d0-46a1-a218-95d609e11729",
324
- "metadata": {},
325
- "outputs": [
326
- {
327
- "data": {
328
- "text/plain": [
329
- "997"
330
- ]
331
- },
332
- "execution_count": 50,
333
- "metadata": {},
334
- "output_type": "execute_result"
335
- }
336
- ],
337
- "source": [
338
- "len(low_t_labels)+len(high_t_labels)"
339
- ]
340
- },
341
- {
342
- "cell_type": "code",
343
- "execution_count": 13,
344
- "id": "c11050db-2636-4c50-9cd4-b9943e5cee83",
345
- "metadata": {},
346
- "outputs": [],
347
- "source": []
348
- },
349
- {
350
- "cell_type": "code",
351
- "execution_count": 16,
352
- "id": "e1309e93-7063-4f48-bbc7-11a0d449c34e",
353
- "metadata": {},
354
- "outputs": [
355
- {
356
- "name": "stdout",
357
- "output_type": "stream",
358
- "text": [
359
- "ROC-AUC Score for High Graduation Rate Group: 0.675\n",
360
- "ROC-AUC Score for Low Graduation Rate Group: 0.7489795918367347\n"
361
- ]
362
- }
363
- ],
364
- "source": [
365
- "high_roc_auc = roc_auc_score(high_t_labels, high_p_labels) if len(set(high_t_labels)) > 1 else None\n",
366
- "low_roc_auc = roc_auc_score(low_t_labels, low_p_labels) if len(set(low_t_labels)) > 1 else None\n",
367
- "\n",
368
- "print(\"ROC-AUC Score for High Graduation Rate Group:\", high_roc_auc)\n",
369
- "print(\"ROC-AUC Score for Low Graduation Rate Group:\", low_roc_auc)"
370
- ]
371
- },
372
- {
373
- "cell_type": "code",
374
- "execution_count": 21,
375
- "id": "a99e7812-817d-4f9f-b6fa-1a58aa3a34dc",
376
- "metadata": {},
377
- "outputs": [
378
- {
379
- "name": "stdout",
380
- "output_type": "stream",
381
- "text": [
382
- "Task Analysis Summary:\n",
383
- "-----------------------\n",
384
- "Ideal Task = OptionalTask_1:\n",
385
- " Only OptionalTask_1 done: 22501\n",
386
- " Only OptionalTask_2 done: 20014\n",
387
- " Both done: 24854\n",
388
- " None done: 38\n",
389
- "Ideal Task = OptionalTask_2:\n",
390
- " Only OptionalTask_1 done: 12588\n",
391
- " Only OptionalTask_2 done: 18942\n",
392
- " Both done: 15147\n",
393
- " None done: 78\n",
394
- "\n"
395
- ]
396
- }
397
- ],
398
- "source": [
399
- "def analyze_row(row):\n",
400
- " # Split the row into fields\n",
401
- " fields = row.split(\"\\t\")\n",
402
- "\n",
403
- " # Define tasks for OptionalTask_1, OptionalTask_2, and FinalAnswer\n",
404
- " optional_task_1_subtasks = [\"DenominatorFactor\", \"NumeratorFactor\", \"EquationAnswer\"]\n",
405
- " optional_task_2_subtasks = [\n",
406
- " \"FirstRow2:1\", \"FirstRow2:2\", \"FirstRow1:1\", \"FirstRow1:2\", \n",
407
- " \"SecondRow\", \"ThirdRow\"\n",
408
- " ]\n",
409
- "\n",
410
- " # Helper function to evaluate task attempts\n",
411
- " def evaluate_tasks(fields, tasks):\n",
412
- " task_status = {}\n",
413
- " for task in tasks:\n",
414
- " relevant_attempts = [f for f in fields if task in f]\n",
415
- " if any(\"OK\" in attempt for attempt in relevant_attempts):\n",
416
- " task_status[task] = \"Attempted (Successful)\"\n",
417
- " elif any(\"ERROR\" in attempt for attempt in relevant_attempts):\n",
418
- " task_status[task] = \"Attempted (Error)\"\n",
419
- " elif any(\"JIT\" in attempt for attempt in relevant_attempts):\n",
420
- " task_status[task] = \"Attempted (JIT)\"\n",
421
- " else:\n",
422
- " task_status[task] = \"Unattempted\"\n",
423
- " return task_status\n",
424
- "\n",
425
- " # Evaluate tasks for each category\n",
426
- " optional_task_1_status = evaluate_tasks(fields, optional_task_1_subtasks)\n",
427
- " optional_task_2_status = evaluate_tasks(fields, optional_task_2_subtasks)\n",
428
- "\n",
429
- " # Check if tasks have any successful attempt\n",
430
- " opt1_done = any(status == \"Attempted (Successful)\" for status in optional_task_1_status.values())\n",
431
- " opt2_done = any(status == \"Attempted (Successful)\" for status in optional_task_2_status.values())\n",
432
- "\n",
433
- " return opt1_done, opt2_done\n",
434
- "\n",
435
- "# Read data from test_info.txt\n",
436
- "# Read data from test_info.txt\n",
437
- "with open(test_info_location, \"r\") as file:\n",
438
- " data = file.readlines()\n",
439
- "\n",
440
- "# Assuming test_info[7] is a list with ideal tasks for each instance\n",
441
- "ideal_tasks = test_info[6] # A list where each element is either 1 or 2\n",
442
- "\n",
443
- "# Initialize counters\n",
444
- "task_counts = {\n",
445
- " 1: {\"only_opt1\": 0, \"only_opt2\": 0, \"both\": 0,\"none\":0},\n",
446
- " 2: {\"only_opt1\": 0, \"only_opt2\": 0, \"both\": 0,\"none\":0}\n",
447
- "}\n",
448
- "\n",
449
- "# Analyze rows\n",
450
- "for i, row in enumerate(data):\n",
451
- " row = row.strip()\n",
452
- " if not row:\n",
453
- " continue\n",
454
- "\n",
455
- " ideal_task = ideal_tasks[i] # Get the ideal task for the current row\n",
456
- " opt1_done, opt2_done = analyze_row(row)\n",
457
- "\n",
458
- " if ideal_task == 0:\n",
459
- " if opt1_done and not opt2_done:\n",
460
- " task_counts[1][\"only_opt1\"] += 1\n",
461
- " elif not opt1_done and opt2_done:\n",
462
- " task_counts[1][\"only_opt2\"] += 1\n",
463
- " elif opt1_done and opt2_done:\n",
464
- " task_counts[1][\"both\"] += 1\n",
465
- " else:\n",
466
- " task_counts[1][\"none\"] +=1\n",
467
- " elif ideal_task == 1:\n",
468
- " if opt1_done and not opt2_done:\n",
469
- " task_counts[2][\"only_opt1\"] += 1\n",
470
- " elif not opt1_done and opt2_done:\n",
471
- " task_counts[2][\"only_opt2\"] += 1\n",
472
- " elif opt1_done and opt2_done:\n",
473
- " task_counts[2][\"both\"] += 1\n",
474
- " else:\n",
475
- " task_counts[2][\"none\"] +=1\n",
476
- "\n",
477
- "# Create a string output for results\n",
478
- "output_summary = \"Task Analysis Summary:\\n\"\n",
479
- "output_summary += \"-----------------------\\n\"\n",
480
- "\n",
481
- "for ideal_task, counts in task_counts.items():\n",
482
- " output_summary += f\"Ideal Task = OptionalTask_{ideal_task}:\\n\"\n",
483
- " output_summary += f\" Only OptionalTask_1 done: {counts['only_opt1']}\\n\"\n",
484
- " output_summary += f\" Only OptionalTask_2 done: {counts['only_opt2']}\\n\"\n",
485
- " output_summary += f\" Both done: {counts['both']}\\n\"\n",
486
- " output_summary += f\" None done: {counts['none']}\\n\"\n",
487
- "\n",
488
- "print(output_summary)\n"
489
- ]
490
- },
491
- {
492
- "cell_type": "code",
493
- "execution_count": 23,
494
- "id": "3630406c-859a-43ab-a569-67d577cc9bf6",
495
- "metadata": {},
496
- "outputs": [],
497
- "source": [
498
- "import gradio as gr\n",
499
- "from matplotlib.figure import Figure"
500
- ]
501
- },
502
- {
503
- "cell_type": "code",
504
- "execution_count": 28,
505
- "id": "99833638-882d-4c75-bcc3-031e39cfb5a7",
506
- "metadata": {},
507
- "outputs": [],
508
- "source": [
509
- "with open(\"roc_data.pkl\", \"rb\") as f:\n",
510
- " fpr, tpr, _ = pickle.load(f)\n",
511
- "roc_auc = auc(fpr, tpr)\n",
512
- "\n",
513
- "# Create a matplotlib figure\n",
514
- "fig = Figure()\n",
515
- "ax = fig.add_subplot(1, 1, 1)\n",
516
- "ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')\n",
517
- "ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')\n",
518
- "ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=f'Receiver Operating Curve (ROC)')\n",
519
- "ax.legend(loc=\"lower right\")\n",
520
- "ax.grid()"
521
- ]
522
- },
523
- {
524
- "cell_type": "code",
525
- "execution_count": null,
526
- "id": "6eb3dece-5b33-4223-af9a-6b999bb2305b",
527
- "metadata": {},
528
- "outputs": [],
529
- "source": []
530
- }
531
- ],
532
- "metadata": {
533
- "kernelspec": {
534
- "display_name": "Python 3 (ipykernel)",
535
- "language": "python",
536
- "name": "python3"
537
- },
538
- "language_info": {
539
- "codemirror_mode": {
540
- "name": "ipython",
541
- "version": 3
542
- },
543
- "file_extension": ".py",
544
- "mimetype": "text/x-python",
545
- "name": "python",
546
- "nbconvert_exporter": "python",
547
- "pygments_lexer": "ipython3",
548
- "version": "3.12.4"
549
- }
550
- },
551
- "nbformat": 4,
552
- "nbformat_minor": 5
553
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fullTest/test.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2a479561b801a43249b6a8aceed5f32d16cec3d2f40956ed02640b6dcab0bdfe
3
- size 21353853
 
 
 
 
fullTest/test_info.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:dbb182b48eecce59c4e61f82a23d8af2866d9327f0543aca3546880fdb0d6003
3
- size 166442240
 
 
 
 
fullTest/test_label.txt DELETED
The diff for this file is too large to render. See raw diff
 
new_test_saved_finetuned_model.py CHANGED
@@ -221,12 +221,9 @@ class BERTFineTuneTrainer:
221
  for key, value in final_msg.items():
222
  file.write(f"{key}: {value}\n")
223
  print(final_msg)
224
- # print(type(plabels),type(tlabels),plabels,tlabels)
225
  fpr, tpr, thresholds = roc_curve(tlabels, positive_class_probs)
226
  with open("roc_data.pkl", "wb") as f:
227
  pickle.dump((fpr, tpr, thresholds), f)
228
- with open("roc_data2.pkl", "wb") as f:
229
- pickle.dump((tlabels,positive_class_probs), f)
230
  print(final_msg)
231
  f.close()
232
  with open(self.log_folder_path+f"/log_{phase}_finetuned_info.txt", 'a') as f1:
@@ -429,7 +426,6 @@ class BERTFineTuneCalibratedTrainer:
429
  auc_score = roc_auc_score(tlabels, positive_class_probs)
430
  end_time = time.time()
431
  final_msg = {
432
- "this one":"this one",
433
  "avg_loss": avg_loss / len(data_iter),
434
  "total_acc": total_correct * 100.0 / total_element,
435
  "precisions": precisions,
@@ -444,8 +440,7 @@ class BERTFineTuneCalibratedTrainer:
444
  with open("result.txt", 'w') as file:
445
  for key, value in final_msg.items():
446
  file.write(f"{key}: {value}\n")
447
- with open("plabels.txt","w") as file:
448
- file.write(plabels)
449
  print(final_msg)
450
  fpr, tpr, thresholds = roc_curve(tlabels, positive_class_probs)
451
  f.close()
 
221
  for key, value in final_msg.items():
222
  file.write(f"{key}: {value}\n")
223
  print(final_msg)
 
224
  fpr, tpr, thresholds = roc_curve(tlabels, positive_class_probs)
225
  with open("roc_data.pkl", "wb") as f:
226
  pickle.dump((fpr, tpr, thresholds), f)
 
 
227
  print(final_msg)
228
  f.close()
229
  with open(self.log_folder_path+f"/log_{phase}_finetuned_info.txt", 'a') as f1:
 
426
  auc_score = roc_auc_score(tlabels, positive_class_probs)
427
  end_time = time.time()
428
  final_msg = {
 
429
  "avg_loss": avg_loss / len(data_iter),
430
  "total_acc": total_correct * 100.0 / total_element,
431
  "precisions": precisions,
 
440
  with open("result.txt", 'w') as file:
441
  for key, value in final_msg.items():
442
  file.write(f"{key}: {value}\n")
443
+
 
444
  print(final_msg)
445
  fpr, tpr, thresholds = roc_curve(tlabels, positive_class_probs)
446
  f.close()
plot.png CHANGED
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/highGRschool10_/test.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:20a028aa7529f6c68f16ba09e038ef969ca61aa22ee1e41f5e0474883aabbddc
3
- size 24775790
 
 
 
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/highGRschool10_/test_info.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f29c4c3585b70ef5a1fc0c107d9d96c63b7adae0659789b90f5bfab97df57026
3
- size 123225375
 
 
 
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/highGRschool10_/test_label.txt DELETED
The diff for this file is too large to render. See raw diff
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:35569d6f81ef85e6353f36912c1cb79bfb723fe7d2476e10afcb745c170c5130
3
- size 24672844
 
 
 
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test_BKT.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:dec171098a0b444d3b8a3de8497345e8806440038756ce51a575314e6c414647
3
- size 20086086
 
 
 
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test_info.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a6aadba0002bfdfde835b8837b3ff36cd84c64c3e23b6589ec1d002b4b62c2f4
3
- size 122629427
 
 
 
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullTest/test_label.txt DELETED
The diff for this file is too large to render. See raw diff
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/highGRschool10/test_label.txt CHANGED
The diff for this file is too large to render. See raw diff
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/lowGRschoolAll/test.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e738c87fcbcc3e0362199ea2b7f9ef06093fb3f9e7a5f8c5ab828602e52230f9
3
- size 16005023
 
 
 
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/lowGRschoolAll/test_info.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ef4862f5c282efdfa49e13ed0f6cb344abcb7ae07fdfba535d48193bb8a3c1ed
3
- size 81939614
 
 
 
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/lowGRschoolAll/test_label.txt DELETED
The diff for this file is too large to render. See raw diff
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/test.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2a479561b801a43249b6a8aceed5f32d16cec3d2f40956ed02640b6dcab0bdfe
3
- size 21353853
 
 
 
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/test_info.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:dbb182b48eecce59c4e61f82a23d8af2866d9327f0543aca3546880fdb0d6003
3
- size 166442240
 
 
 
 
ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/test_label.txt DELETED
The diff for this file is too large to render. See raw diff
 
result.txt CHANGED
@@ -1,7 +1,7 @@
1
- avg_loss: 0.5841353535652161
2
- total_acc: 69.00702106318957
3
- precisions: 0.7236623191454734
4
- recalls: 0.6900702106318957
5
- f1_scores: 0.6802420656474512
6
- time_taken_from_start: 25.420082330703735
7
- auc_score: 0.7457100293916334
 
1
+ avg_loss: 0.5730699896812439
2
+ total_acc: 69.52861952861953
3
+ precisions: 0.7336375047795977
4
+ recalls: 0.6952861952861953
5
+ f1_scores: 0.6858177547541179
6
+ time_taken_from_start: 28.49159860610962
7
+ auc_score: 0.7738852057033876
roc_data.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1eec496eaf0223a6c65d0367eb586c968edc655b7e4c601d35db358f8419047d
3
- size 9437
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c4af99c21a2122f6f4c4773439bbb77976243559acf78cd9b771f24d3ae9bdc
3
+ size 5930
roc_data2.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:41fa9d96833c12979f8495141ee61c0ba07d4a20c5fb5bc18a7f72bf4d15e8fd
3
- size 28023
 
 
 
 
selected_rows.txt CHANGED
The diff for this file is too large to render. See raw diff
 
test.txt ADDED
The diff for this file is too large to render. See raw diff
 
train.txt ADDED
The diff for this file is too large to render. See raw diff
 
train_info.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ef4862f5c282efdfa49e13ed0f6cb344abcb7ae07fdfba535d48193bb8a3c1ed
3
- size 81939614
 
 
 
 
train_label.txt CHANGED
The diff for this file is too large to render. See raw diff