Vivek commited on
Commit
e1fc256
1 Parent(s): 42bbcd0

final changes

Browse files
Files changed (2) hide show
  1. src/piqa_predictions.csv +385 -0
  2. src/test_piqa.py +3 -3
src/piqa_predictions.csv ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,predictions,permutation
2
+ 0,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2288 625 1674 1729 1492 641 196 2138 1716 1790 1884 223 971 51
3
+ 735 1123]"
4
+ 1,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 762 1021 2749 1896 1118 1697 184 3024 1108 2816 2204 1316 1613 3018
5
+ 1799 1054]"
6
+ 2,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 632 11 1985 2129 788 1787 1571 994 2616 1449 2378 1378 1946 2930
7
+ 2281 2149]"
8
+ 3,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2019 792 2842 2463 1689 2529 1186 749 1061 456 556 1834 253 336
9
+ 648 1977]"
10
+ 4,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1536 3051 252 2002 1100 579 913 2321 1788 1683 1853 62 780 818
11
+ 2208 2131]"
12
+ 5,[1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1],"[1806 1235 19 2941 2181 862 2369 837 2819 2646 263 2650 2403 817
13
+ 2184 2773]"
14
+ 6,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1608 2886 1940 2738 1860 1711 2224 345 2663 2528 1610 2008 1169 129
15
+ 806 711]"
16
+ 7,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1046 809 2353 2111 3000 2955 1715 1654 136 1682 387 2703 2221 757
17
+ 2978 664]"
18
+ 8,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2662 2755 2849 537 1395 589 2362 327 712 768 2371 2848 2788 449
19
+ 1591 2865]"
20
+ 9,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1181 214 2594 2581 515 693 421 29 985 2609 2828 873 895 1019
21
+ 3 2392]"
22
+ 10,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2442 2613 2564 889 2047 1116 2260 2817 529 4 1000 822 885 2718
23
+ 1890 475]"
24
+ 11,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 481 1523 322 975 1531 560 2120 2939 595 2756 444 1733 2066 1422
25
+ 787 1649]"
26
+ 12,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 751 1157 256 2136 2709 2720 282 2177 1997 2122 2722 286 2099 78
27
+ 2275 2875]"
28
+ 13,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2723 531 1267 1379 2813 1 610 1269 789 2174 1907 2919 2113 340
29
+ 2 2911]"
30
+ 14,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1780 1913 489 1930 2341 2382 1844 2486 173 1318 1585 389 1261 30
31
+ 854 2917]"
32
+ 15,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2603 487 1713 2093 2130 2554 1839 1281 1601 2427 1107 2380 2333 2025
33
+ 2507 1479]"
34
+ 16,[1 1 0 1 1 1 1 0 1 1 1 1 1 1 1 1],"[ 840 653 2491 2357 2653 2394 1166 65 3049 93 1841 1420 1386 3009
35
+ 261 1102]"
36
+ 17,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1757 1431 379 2948 728 1680 2635 1652 2255 1559 2273 857 2754 1285
37
+ 143 2541]"
38
+ 18,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2540 1547 311 1690 2968 616 262 2644 2627 142 332 1926 1137 215
39
+ 543 2470]"
40
+ 19,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1400 847 1374 888 2189 292 2092 2385 1953 2381 2637 172 2355 9
41
+ 245 2946]"
42
+ 20,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 67 2294 1791 2805 2452 2675 2033 1845 790 1180 2556 944 1866 220
43
+ 1519 1245]"
44
+ 21,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1387 1055 1047 1858 397 1383 39 1745 2896 74 1179 1589 1535 2395
45
+ 2789 2966]"
46
+ 22,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1572 1606 429 2571 726 1867 663 72 2976 2579 297 829 212 1253
47
+ 798 277]"
48
+ 23,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2707 2352 1476 329 1371 1782 1620 957 1840 2927 1488 1425 1777 1276
49
+ 2782 621]"
50
+ 24,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 37 442 2312 1082 1868 1357 591 1359 781 259 374 2539 2710 1136
51
+ 2406 2012]"
52
+ 25,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1891 2094 3066 707 2687 1201 2683 2445 346 1804 686 3045 2240 243
53
+ 2977 1740]"
54
+ 26,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2276 1876 2164 2263 2947 2578 1964 1663 140 1148 2591 984 1370 2402
55
+ 2360 1421]"
56
+ 27,[1 1 0 1 1 1 1 1 1 0 1 1 1 1 1 1],"[2182 2430 932 2246 71 2981 407 2202 1875 57 2933 1209 95 2187
57
+ 371 1282]"
58
+ 28,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1402 2314 395 725 503 3047 2667 1372 849 418 1010 2699 2429 2472
59
+ 86 2929]"
60
+ 29,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2938 1404 1337 310 2322 1950 1925 1277 1230 963 1490 1963 2582 1614
61
+ 63 2553]"
62
+ 30,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1389 1345 2155 633 1226 1144 1566 3063 1934 592 566 2387 1438 1159
63
+ 3015 178]"
64
+ 31,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2293 553 958 194 1205 419 1673 2907 94 2568 594 1895 1494 2797
65
+ 3048 2542]"
66
+ 32,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2692 2226 1455 863 1304 2359 1279 876 554 1229 1597 352 2471 1700
67
+ 404 2376]"
68
+ 33,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 966 1686 1881 867 231 842 2422 644 1541 679 859 1657 1351 1769
69
+ 1329 2342]"
70
+ 34,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1644 2778 2913 2890 1086 254 1786 890 368 2085 1864 608 1609 2084
71
+ 3040 1855]"
72
+ 35,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 125 1477 2792 2851 540 2205 1615 811 1943 484 451 512 832 1078
73
+ 1691 2323]"
74
+ 36,[1 1 0 1 1 1 1 0 1 1 1 1 1 1 1 1],"[1210 1331 521 1117 799 2982 2954 2310 1149 3074 823 3064 1772 2327
75
+ 2102 111]"
76
+ 37,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 313 774 2854 1838 1445 1598 1439 82 833 1760 1222 1748 2685 2868
77
+ 2161 1001]"
78
+ 38,[1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1],"[1499 1984 96 1696 1756 2279 617 2437 933 1029 66 2053 1775 1432
79
+ 720 1264]"
80
+ 39,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 758 2268 1417 1687 1407 694 2940 2549 899 619 2159 3012 2467 2705
81
+ 2713 144]"
82
+ 40,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 709 602 432 1638 887 676 1058 2286 20 891 413 2063 526 866
83
+ 786 1765]"
84
+ 41,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 905 2520 1301 321 1309 2089 1701 3052 3081 428 1626 31 2623 2742
85
+ 2666 1632]"
86
+ 42,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2282 2625 2648 2798 1981 1365 1223 1234 1120 2118 384 2884 2866 1579
87
+ 902 472]"
88
+ 43,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1004 1848 2370 1085 795 2814 2806 1577 2473 1283 815 2967 2334 1453
89
+ 925 1919]"
90
+ 44,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2152 1485 2799 2207 2958 593 2761 1768 1284 1558 2116 137 1900 2074
91
+ 1831 574]"
92
+ 45,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 344 880 2254 1406 2040 1274 315 2153 1461 2545 2388 674 2304 3075
93
+ 1521 2479]"
94
+ 46,[1 1 0 1 1 1 1 1 0 1 1 1 1 1 1 1],"[1607 2444 294 1936 2548 1921 1920 919 113 951 532 2030 1290 1568
95
+ 590 2741]"
96
+ 47,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1527 2862 1849 565 324 539 1918 877 1944 304 1474 493 2555 2209
97
+ 426 1265]"
98
+ 48,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1175 1336 830 3079 651 1041 307 1382 804 341 544 2227 721 2809
99
+ 1617 1562]"
100
+ 49,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 75 2105 2344 2787 2125 50 584 2711 2483 2121 1081 834 2831 1307
101
+ 1537 1288]"
102
+ 50,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2511 1113 267 2165 2020 2952 21 2604 684 1730 904 1083 2557 620
103
+ 2589 753]"
104
+ 51,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2306 1368 3058 893 772 1991 1418 1789 710 2489 2962 575 716 1974
105
+ 370 1064]"
106
+ 52,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 918 2058 2459 36 2639 2237 831 3021 2325 2630 613 836 462 1437
107
+ 249 1130]"
108
+ 53,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1339 1190 1948 2488 2280 115 576 467 224 1153 2397 2606 1685 997
109
+ 1892 2953]"
110
+ 54,[1 1 0 1 1 1 1 0 1 1 1 1 1 1 1 1],"[2198 2901 1384 2374 1814 1320 2559 25 1670 2454 2218 1034 342 1040
111
+ 1071 942]"
112
+ 55,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1470 570 1738 133 1104 1323 146 1955 2678 2903 2061 1396 2000 1567
113
+ 2003 1998]"
114
+ 56,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2267 959 64 717 514 1986 1723 2335 766 377 1143 2626 1507 2822
115
+ 631 2057]"
116
+ 57,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2133 1851 677 2887 967 1241 2964 1008 2302 1141 1524 2836 1671 298
117
+ 634 2526]"
118
+ 58,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1594 2620 767 2693 126 2587 2060 1575 734 2864 845 2994 2521 1127
119
+ 2339 2993]"
120
+ 59,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1239 1927 2524 296 596 2611 2769 2042 400 2891 351 1656 1699 1611
121
+ 785 1636]"
122
+ 60,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1193 1545 32 2052 1219 1731 2618 326 123 626 839 724 2299 1030
123
+ 525 1012]"
124
+ 61,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1593 2001 2716 354 2570 1823 2795 2428 2173 2597 2446 2469 1949 777
125
+ 1018 800]"
126
+ 62,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1503 227 886 2676 2700 1192 2062 689 2780 2049 2600 2970 989 2496
127
+ 2126 733]"
128
+ 63,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 6 2354 2960 1198 1658 2225 425 2904 2250 3023 2833 455 1022 3053
129
+ 1482 2641]"
130
+ 64,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 275 2236 1491 430 2044 1441 692 1154 1931 698 1428 1009 2959 2424
131
+ 567 1204]"
132
+ 65,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1147 1511 2924 746 510 2337 2179 1662 1554 2777 802 1871 1909 1424
133
+ 938 8]"
134
+ 66,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2169 1797 604 76 2830 217 2021 1165 1035 386 1822 1429 838 1549
135
+ 1306 2413]"
136
+ 67,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1661 52 2717 2509 2517 3005 2112 1079 1299 928 2432 1989 1672 1781
137
+ 364 117]"
138
+ 68,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 681 1450 420 415 2835 2880 559 1352 2046 2984 1038 657 1514 392
139
+ 159 708]"
140
+ 69,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1707 1530 1886 722 1710 1973 1837 218 1573 1291 468 848 2892 2668
141
+ 1522 599]"
142
+ 70,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2140 1305 1709 1550 2850 1803 1543 3062 2192 1630 83 2925 193 1321
143
+ 41 2760]"
144
+ 71,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 939 803 1642 2197 372 2004 1624 1811 517 301 1774 2398 2265 740
145
+ 1529 1684]"
146
+ 72,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1750 2628 422 1681 2363 2128 737 1941 2945 2815 2498 821 1534 209
147
+ 132 990]"
148
+ 73,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[3037 2143 2145 2171 2163 901 1497 2995 1216 2210 1080 15 1728 1353
149
+ 1182 1286]"
150
+ 74,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1600 1442 1249 2196 2990 2067 1726 2936 953 2726 1939 856 112 779
151
+ 2421 2870]"
152
+ 75,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 700 1882 238 808 2073 213 1968 1189 385 643 1155 1348 424 1233
153
+ 3022 3003]"
154
+ 76,[1 1 0 1 1 1 1 1 0 1 1 1 1 0 1 1],"[2951 2379 2783 2583 2076 1912 581 2839 1451 1220 1214 360 2492 943
155
+ 796 1278]"
156
+ 77,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1314 2767 349 232 1599 190 2987 2689 2320 1240 2736 2530 2100 1125
157
+ 1631 2083]"
158
+ 78,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 605 1324 2038 1072 1173 2330 1213 2248 1358 705 1363 1152 736 2558
159
+ 2812 2636]"
160
+ 79,[1 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1],"[ 927 182 2468 308 1057 1397 962 2657 107 640 1168 2610 2495 2972
161
+ 3077 13]"
162
+ 80,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 465 846 1719 1990 1075 2847 2621 2811 2253 168 1459 2790 273 1448
163
+ 1273 763]"
164
+ 81,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2331 1197 391 937 1770 244 1258 871 2433 1232 2262 2026 1156 2512
165
+ 511 2883]"
166
+ 82,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1094 380 3078 205 139 1801 979 946 1516 1679 2162 2108 1809 328
167
+ 1665 3029]"
168
+ 83,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 128 347 2567 2242 80 1650 1767 448 1017 1355 770 312 2440 1303
169
+ 390 1242]"
170
+ 84,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2045 1796 158 1262 1904 1419 1119 412 2580 2295 1962 1698 1722 33
171
+ 2802 2991]"
172
+ 85,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1218 2734 3055 1628 952 696 453 358 1060 1586 2729 1747 2617 445
173
+ 2645 151]"
174
+ 86,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 995 2980 1916 1587 597 2642 2655 2056 3043 744 485 299 2857 1821
175
+ 1486 2900]"
176
+ 87,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 309 1254 1252 2332 2660 1066 1024 2211 1510 2672 222 1067 210 1862
177
+ 1563 2527]"
178
+ 88,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2059 1741 2348 1133 678 1170 491 662 2035 2201 1640 755 600 2043
179
+ 1893 1505]"
180
+ 89,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2217 2300 2425 1829 627 1039 977 1099 2681 2301 1347 1489 2119 858
181
+ 1548 1676]"
182
+ 90,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2708 2586 108 630 580 2214 77 2447 45 2508 337 2500 103 437
183
+ 2501 1798]"
184
+ 91,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 750 3001 702 1532 623 1861 411 1340 988 924 1651 2343 1805 2243
185
+ 2871 2271]"
186
+ 92,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1298 2305 3082 1574 2725 1994 783 908 2298 2490 1565 1238 2652 1051
187
+ 896 2190]"
188
+ 93,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 175 2055 1430 3026 2005 306 671 2416 1957 1332 1754 1693 1434 248
189
+ 2766 1915]"
190
+ 94,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 639 548 872 2565 2127 1995 666 2664 956 2007 504 1704 3083 2588
191
+ 582 682]"
192
+ 95,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 135 1464 305 1762 535 2144 964 1026 2631 3028 665 2794 355 2465
193
+ 1528 1888]"
194
+ 96,[1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2054 2048 1399 458 1820 524 378 2405 1473 547 1982 813 2574 2974
195
+ 2032 53]"
196
+ 97,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2420 748 2011 2186 2910 2845 2969 1380 2732 2106 2876 816 2750 469
197
+ 2996 443]"
198
+ 98,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2232 558 120 2989 2132 436 2497 1109 1742 983 330 2753 348 3041
199
+ 690 1131]"
200
+ 99,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2028 216 706 606 2535 2475 240 1744 325 973 1958 814 2961 2206
201
+ 611 2408]"
202
+ 100,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1126 1588 546 844 12 1111 2431 1317 922 645 153 1478 1007 1457
203
+ 2605 1187]"
204
+ 101,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1184 483 471 40 1202 2577 255 1176 2238 2256 3039 2259 714 2401
205
+ 1247 401]"
206
+ 102,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1515 1506 2804 2191 1533 189 2881 588 488 2566 2095 628 1257 192
207
+ 1525 900]"
208
+ 103,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2893 2872 3019 954 1471 202 704 974 650 986 1509 2436 1993 2350
209
+ 457 797]"
210
+ 104,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 545 1979 1612 2474 1581 2746 685 1520 635 2898 2843 5 2502 1335
211
+ 2840 991]"
212
+ 105,[1 1 0 1 1 1 1 0 1 1 1 1 1 1 1 1],"[2170 968 573 1164 2137 2396 2423 2649 2453 1645 89 2514 2704 0
213
+ 2228 1735]"
214
+ 106,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 163 369 155 1246 1897 285 2552 533 2739 680 638 1561 3007 3073
215
+ 295 1908]"
216
+ 107,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2624 1403 2477 1013 1773 1659 1692 1988 44 752 2592 350 578 1675
217
+ 1101 2135]"
218
+ 108,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 649 2573 1132 1655 878 1576 1287 1289 2213 164 3038 1815 1375 2373
219
+ 59 518]"
220
+ 109,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 561 2091 2728 2292 1446 406 114 2361 1172 636 1808 530 333 2965
221
+ 2671 1526]"
222
+ 110,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1237 2824 55 2364 564 745 257 3035 2844 2328 1140 302 1901 1972
223
+ 2957 2257]"
224
+ 111,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1195 2963 2193 2807 996 1732 1816 850 1551 981 1880 2478 423 2086
225
+ 2438 22]"
226
+ 112,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1088 2634 2272 884 3013 1842 148 1538 771 1500 2590 463 1824 2602
227
+ 1228 1894]"
228
+ 113,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 258 279 541 1466 1899 188 1211 2522 1604 1145 1042 2081 2544 2278
229
+ 152 1660]"
230
+ 114,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2519 18 947 1124 3027 207 1877 474 1825 2665 1044 476 948 320
231
+ 2543 775]"
232
+ 115,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1322 1266 2375 2031 1362 183 154 1648 2599 1225 731 660 1068 2869
233
+ 998 1056]"
234
+ 116,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2727 3069 2103 1889 1452 2935 2979 661 1618 2318 921 2283 855 1312
235
+ 738 955]"
236
+ 117,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1669 1954 104 2233 402 778 3017 1504 1813 2882 459 264 2715 2856
237
+ 1158 2682]"
238
+ 118,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2986 852 1151 1423 1025 1221 1203 230 1971 1595 1227 2212 909 1070
239
+ 1724 2069]"
240
+ 119,[1 1 0 1 1 1 1 0 1 1 1 1 1 1 1 1],"[ 234 180 930 2730 2879 2448 2168 2740 508 1677 2764 1271 1952 1592
241
+ 272 1346]"
242
+ 120,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1795 1555 2684 494 723 2109 1518 2629 1200 520 1405 1295 1737 2863
243
+ 2462 2894]"
244
+ 121,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 366 90 1174 2037 1297 978 2285 914 1256 1369 1917 206 697 3042
245
+ 1112 2215]"
246
+ 122,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2771 1752 2748 2853 2199 2087 3033 2006 853 162 2451 2264 2460 1408
247
+ 587 1580]"
248
+ 123,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2185 2249 280 2534 1970 334 2123 1546 2923 2719 265 2758 1293 1236
249
+ 233 2251]"
250
+ 124,[1 1 0 1 1 1 1 1 0 1 1 1 1 1 1 1],"[ 879 915 2688 810 1129 42 1050 765 1854 28 2064 373 2175 388
251
+ 276 2640]"
252
+ 125,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2658 1779 34 1807 1069 1736 1924 1440 2622 3070 1334 1714 38 1725
253
+ 204 2347]"
254
+ 126,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1392 1502 2414 1637 1167 100 3071 399 1005 2041 571 1911 2897 289
255
+ 2439 2506]"
256
+ 127,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2918 1879 318 870 1458 1463 2858 629 1354 1923 903 825 1817 35
257
+ 2859 2450]"
258
+ 128,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1045 288 1759 976 1028 1409 338 1603 17 1706 3011 1865 1215 300
259
+ 447 727]"
260
+ 129,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 490 208 2252 851 960 479 647 1749 2909 2860 965 2823 2943 439
261
+ 2829 2560]"
262
+ 130,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 167 2808 2679 1826 1812 929 2632 1784 1621 550 897 562 2651 2141
263
+ 1150 3030]"
264
+ 131,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2178 1065 2546 1103 2861 655 1224 764 2934 1751 1746 864 186 405
265
+ 507 1544]"
266
+ 132,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 219 2562 875 1542 1208 609 274 2757 170 2345 1207 2464 1764 828
267
+ 970 2712]"
268
+ 133,[1 1 0 1 1 1 1 1 1 0 1 1 1 0 1 1],"[ 157 84 618 672 221 1243 2014 303 2743 1605 1883 1089 266 1023
269
+ 2837 2670]"
270
+ 134,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1114 906 2239 2418 1755 1376 287 1708 2538 270 563 826 713 1978
271
+ 283 1178]"
272
+ 135,[1 1 0 1 1 1 1 0 1 1 1 1 1 1 1 1],"[2585 166 2157 2482 549 555 1325 1341 2786 2785 2547 92 416 359
273
+ 1539 1898]"
274
+ 136,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 926 824 1255 2476 2928 1300 2873 1077 827 236 3031 91 1037 2902
275
+ 1776 2082]"
276
+ 137,[1 1 0 1 1 1 1 1 0 0 1 1 1 0 1 1],"[ 793 2768 1734 3036 1194 1758 16 776 2311 1206 2762 171 1106 1098
277
+ 1872 1361]"
278
+ 138,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2230 1251 1743 2694 440 2023 2134 1517 433 2455 659 522 1087 1142
279
+ 2183 2931]"
280
+ 139,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1540 1217 2346 131 2383 1856 480 486 2697 2114 1495 2449 2080 1498
281
+ 141 1961]"
282
+ 140,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2801 2695 2515 1870 121 1717 58 2222 1128 181 1410 615 2034 1810
283
+ 3061 1641]"
284
+ 141,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 197 1718 2781 1584 794 1983 2499 2494 1481 169 499 2078 527 2518
285
+ 2148 56]"
286
+ 142,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2220 1385 2677 470 1623 1720 835 2796 2834 1183 2194 1073 2999 2763
287
+ 2973 495]"
288
+ 143,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 477 1416 585 2793 2800 528 1996 2827 1110 2698 646 687 1115 2696
289
+ 668 2765]"
290
+ 144,[1 1 0 1 1 1 0 1 1 1 1 1 1 1 1 1],"[2803 2050 2774 538 2956 2638 741 3072 2503 3060 290 1161 1616 2772
291
+ 2274 542]"
292
+ 145,[1 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1],"[ 427 2536 49 1327 807 506 466 147 1688 211 2399 3004 801 1590
293
+ 48 1049]"
294
+ 146,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 118 99 2419 1552 2296 2088 2888 239 3057 106 820 2841 935 2942
295
+ 1171 669]"
296
+ 147,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 940 2234 2203 2297 1863 1966 2647 1622 1349 2456 2915 105 1436 2147
297
+ 1999 1333]"
298
+ 148,[1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1],"[1987 179 73 235 1162 1942 509 2906 1121 414 861 2390 98 3016
299
+ 773 1048]"
300
+ 149,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2724 251 1887 500 195 316 2593 2691 268 2937 1869 1027 1296 1138
301
+ 2485 2916]"
302
+ 150,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1393 1443 2601 1753 61 1270 2608 2349 14 819 130 2481 2200 2188
303
+ 1231 2154]"
304
+ 151,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2654 729 742 150 1390 110 911 1003 2372 2319 658 2377 1933 1596
305
+ 534 1625]"
306
+ 152,[1 1 0 1 1 1 1 1 0 1 1 1 1 1 1 1],"[ 496 2885 552 691 703 2258 2340 1956 536 1134 323 452 920 2912
307
+ 2351 1052]"
308
+ 153,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2022 2426 2441 1582 353 1415 446 2775 2563 2065 291 1076 2920 2090
309
+ 403 622]"
310
+ 154,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1929 1512 3032 1447 2075 2820 2702 2317 916 2072 505 1467 1480 2673
311
+ 551 367]"
312
+ 155,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 271 138 450 2356 1633 1678 343 718 2784 2867 2899 2735 1828 699
313
+ 1401 1483]"
314
+ 156,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1846 2336 473 1959 987 417 1096 577 2229 60 2690 160 894 246
315
+ 2071 3010]"
316
+ 157,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2270 1653 1557 260 43 24 2412 1338 1196 1885 2914 1343 2070 2510
317
+ 1484 898]"
318
+ 158,[1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1],"[1097 2016 2018 1344 869 2017 335 1326 1313 2855 2595 557 2386 756
319
+ 382 2167]"
320
+ 159,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2846 654 2247 2384 1090 950 516 1191 2745 2241 1394 1647 1427 2443
321
+ 2266 1360]"
322
+ 160,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2027 3034 156 203 2826 882 2160 2176 1074 1248 1932 187 2821 2326
323
+ 161 612]"
324
+ 161,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1031 2368 1122 2659 1139 1475 1062 2415 1250 2407 2389 1177 2223 2576
325
+ 102 675]"
326
+ 162,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2523 3014 2139 1508 1319 1160 393 2619 3046 1381 760 69 438 1832
327
+ 2158 1011]"
328
+ 163,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1847 2150 1472 1426 2516 715 1377 1668 1619 732 3044 2550 2010 1905
329
+ 1667 982]"
330
+ 164,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 519 782 892 2921 1308 993 1468 2110 1185 1712 1556 2932 2596 225
331
+ 923 2532]"
332
+ 165,[1 1 0 1 1 1 1 0 1 1 1 1 0 0 1 1],"[1947 2315 1827 1263 2457 1292 1960 363 1857 365 2287 2569 2674 2922
333
+ 523 2261]"
334
+ 166,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2895 2615 1705 1951 1703 46 719 10 228 934 1462 409 2434 2575
335
+ 759 482]"
336
+ 167,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2231 881 1398 1906 791 1391 1212 1414 54 2737 1902 502 2656 1836
337
+ 2277 730]"
338
+ 168,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 688 941 1766 2504 2417 2269 1859 26 3080 1976 1771 2458 1014 2195
339
+ 361 2079]"
340
+ 169,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2612 317 116 431 87 2614 1639 2598 498 177 2077 1260 568 2770
341
+ 394 119]"
342
+ 170,[1 1 0 0 1 1 1 1 1 1 1 1 1 0 1 1],"[1411 2525 1146 2329 754 70 1280 2701 1435 1878 1493 1819 1843 2393
343
+ 2487 1793]"
344
+ 171,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2107 124 2366 1272 2633 2744 2142 2838 1634 1002 2013 2572 1084 2825
345
+ 999 2992]"
346
+ 172,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 460 805 972 1444 2015 2661 2878 2216 2101 81 2759 2309 23 176
347
+ 381 461]"
348
+ 173,[1 1 0 1 1 1 1 0 1 1 1 1 1 0 1 1],"[1818 1910 229 314 383 1583 1850 478 435 2480 1199 598 226 2024
349
+ 1033 2944]"
350
+ 174,[1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1],"[1330 1310 1778 293 1469 2411 408 2643 2117 2533 1578 2358 2124 1015
351
+ 2219 2307]"
352
+ 175,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 376 2513 917 1036 1311 701 357 3076 2316 2971 874 27 607 2146
353
+ 2791 3050]"
354
+ 176,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 980 1465 1802 3025 2493 1294 3006 2949 127 670 2466 860 410 7
355
+ 1783 1800]"
356
+ 177,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 191 673 464 1053 2551 642 2308 586 936 2531 1268 278 843 2172
357
+ 1792 3008]"
358
+ 178,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 652 2988 2714 241 1835 1460 1006 2391 165 2484 2607 331 200 1945
359
+ 812 68]"
360
+ 179,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1105 434 2505 284 1350 841 2410 1553 2180 2810 319 1903 2731 1496
361
+ 1721 743]"
362
+ 180,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[1627 2039 2338 2998 1664 2751 569 1874 1244 910 1413 2776 356 2097
363
+ 88 1092]"
364
+ 181,[1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2404 1969 1388 2289 1373 1975 2435 2151 2104 97 992 1163 1646 961
365
+ 1965 2975]"
366
+ 182,[1 1 0 1 1 1 1 1 1 1 1 1 0 0 1 1],"[2245 454 667 2365 2098 247 1020 497 1487 1569 1412 3059 2400 1328
367
+ 1513 1043]"
368
+ 183,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2244 883 1763 269 145 441 3065 2115 122 1560 1873 2985 1570 2706
369
+ 695 865]"
370
+ 184,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[ 637 1761 2721 1702 3068 396 2291 1302 2950 1695 1454 683 1830 1928
371
+ 2680 2561]"
372
+ 185,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2036 2818 1914 2908 603 1967 2832 1093 185 398 242 1935 769 969
373
+ 945 2889]"
374
+ 186,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2409 739 3067 174 907 1188 2926 1032 2235 1259 79 101 339 47
375
+ 2779 656]"
376
+ 187,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[ 949 3056 1980 1063 1694 2324 2303 1059 2461 2284 1666 513 1727 601
377
+ 2029 1456]"
378
+ 188,[1 1 0 1 1 1 1 1 1 1 0 1 1 0 1 1],"[2290 1356 583 1635 2166 2584 1992 1602 1643 2009 1794 1275 624 1501
379
+ 250 201]"
380
+ 189,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[2852 2051 1342 501 1937 2752 2068 1095 1367 2367 747 2313 931 149
381
+ 2997 492]"
382
+ 190,[1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1],"[2877 198 85 2669 1852 614 2733 3054 362 784 3020 2096 1016 2983
383
+ 572 134]"
384
+ 191,[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1],"[1366 109 2905 1629 375 2686 1922 1315 912 2537 1739 281 1938 2874
385
+ 2747 1135]"
src/test_piqa.py CHANGED
@@ -24,7 +24,7 @@ def preprocess(example):
24
 
25
  test_dataset=dataset['test'].map(preprocess)
26
 
27
- len_test_dataset=100
28
 
29
  test_dataset=test_dataset.select(range(len_test_dataset))
30
 
@@ -65,7 +65,7 @@ total_batch_size=16
65
 
66
  from model_file import FlaxGPTNeoForMultipleChoice
67
 
68
- model = FlaxGPTNeoForMultipleChoice.from_pretrained('Vivek/gptneo_hellaswag',input_shape=(1,num_choices,1))
69
 
70
  restored_output=[]
71
  rng, input_rng = jax.random.split(rng)
@@ -75,4 +75,4 @@ for idx,batch in enumerate(glue_test_data_loader(input_rng, test_dataset, total_
75
  restored_output.append(final_output)
76
 
77
  finall=pd.DataFrame({'predictions':restored_output,'permutation':list1})
78
- finall.to_csv('./piqa_predictions.csv')
 
24
 
25
  test_dataset=dataset['test'].map(preprocess)
26
 
27
+ len_test_dataset=3084
28
 
29
  test_dataset=test_dataset.select(range(len_test_dataset))
30
 
 
65
 
66
  from model_file import FlaxGPTNeoForMultipleChoice
67
 
68
+ model = FlaxGPTNeoForMultipleChoice.from_pretrained('Vivek/gptneo_piqa',input_shape=(1,num_choices,1))
69
 
70
  restored_output=[]
71
  rng, input_rng = jax.random.split(rng)
 
75
  restored_output.append(final_output)
76
 
77
  finall=pd.DataFrame({'predictions':restored_output,'permutation':list1})
78
+ finall.to_csv('./piqa_predictions.csv')