|
"""Implementation of the Uni-mol+ model with alterations to the original model.""" |
|
|
|
from typing import Any, Optional, Tuple |
|
|
|
import torch |
|
import torch.nn.functional as f |
|
from torch import nn |
|
from transformers.modeling_utils import PreTrainedModel |
|
from .configuration_atomformer import AtomformerConfig |
|
|
|
|
|
ATOM_METADATA = [ |
|
[ |
|
0.0, |
|
0.0, |
|
0.0, |
|
0.0, |
|
0.0, |
|
0.0, |
|
0.0, |
|
0.106761565836299, |
|
0.4573170731707318, |
|
0.46896368424867707, |
|
0.0, |
|
0.0, |
|
0.0027383806383189145, |
|
0.0, |
|
1.0, |
|
0.0, |
|
0.0, |
|
], |
|
[ |
|
0.008547008547008548, |
|
0.010187317385107808, |
|
0.011235955056179775, |
|
0.008547008547008548, |
|
0.008547008547008548, |
|
0.0, |
|
1.0, |
|
0.0, |
|
-1.0, |
|
0.9999999999999999, |
|
2.1731754967921256e-06, |
|
-1.0, |
|
0.0, |
|
0.010000000000000002, |
|
0.3588318085855031, |
|
0.0, |
|
-1.0, |
|
], |
|
[ |
|
0.017094017094017096, |
|
0.02018415404448405, |
|
0.02247191011235955, |
|
0.017094017094017096, |
|
0.017094017094017096, |
|
0.16666666666666666, |
|
0.0, |
|
0.5729537366548044, |
|
0.08536585365853658, |
|
0.0723802160098582, |
|
0.01302222611458848, |
|
0.1117635470484688, |
|
0.2746530986669577, |
|
0.010000000000000002, |
|
0.2454609429978888, |
|
0.16666666666666666, |
|
0.0, |
|
], |
|
[ |
|
0.025641025641025644, |
|
0.027228539455021038, |
|
0.028089887640449437, |
|
0.025641025641025644, |
|
0.025641025641025644, |
|
0.16666666666666666, |
|
0.058823529411764705, |
|
0.32384341637010683, |
|
0.2652439024390244, |
|
0.2623432478797689, |
|
0.0451198574701265, |
|
0.39298038243761085, |
|
0.4668171696125004, |
|
0.015, |
|
0.12181562280084446, |
|
0.16666666666666666, |
|
0.14285714285714285, |
|
], |
|
[ |
|
0.03418803418803419, |
|
0.03334773276914757, |
|
0.033707865168539325, |
|
0.03418803418803419, |
|
0.03418803418803419, |
|
0.16666666666666666, |
|
0.7058823529411764, |
|
0.25266903914590755, |
|
0.4085365853658537, |
|
0.2128252833015198, |
|
0.057071103187614054, |
|
0.6504807478441018, |
|
0.715419845245687, |
|
0.015, |
|
0.06558761435608726, |
|
0.16666666666666666, |
|
0.2857142857142857, |
|
], |
|
[ |
|
0.042735042735042736, |
|
0.03742946260625253, |
|
0.033707865168539325, |
|
0.042735042735042736, |
|
0.042735042735042736, |
|
0.16666666666666666, |
|
0.7647058823529411, |
|
0.14946619217081855, |
|
0.5640243902439024, |
|
0.3559765143644139, |
|
0.055363782370830124, |
|
1.0000000000000002, |
|
0.7324707832177849, |
|
0.020000000000000004, |
|
0.04327938071780436, |
|
0.16666666666666666, |
|
0.42857142857142855, |
|
], |
|
[ |
|
0.051282051282051294, |
|
0.04421873990197045, |
|
0.03932584269662921, |
|
0.051282051282051294, |
|
0.051282051282051294, |
|
0.16666666666666666, |
|
0.8235294117647058, |
|
0.09252669039145908, |
|
0.7134146341463414, |
|
0.514180781404789, |
|
2.8295183993586364e-05, |
|
0.012484827687008686, |
|
0.012471056032792366, |
|
0.025, |
|
0.06657283603096412, |
|
0.16666666666666666, |
|
0.5714285714285714, |
|
], |
|
[ |
|
0.05982905982905984, |
|
0.050994411431564704, |
|
0.0449438202247191, |
|
0.05982905982905984, |
|
0.05982905982905984, |
|
0.16666666666666666, |
|
0.8823529411764706, |
|
0.05693950177935947, |
|
0.8353658536585367, |
|
0.4699156740039142, |
|
3.268543752245935e-05, |
|
0.00923366315240946, |
|
0.014660396468409729, |
|
0.025, |
|
0.057987332864180154, |
|
0.16666666666666666, |
|
0.7142857142857142, |
|
], |
|
[ |
|
0.06837606837606838, |
|
0.06119533458279619, |
|
0.056179775280898875, |
|
0.06837606837606838, |
|
0.06837606837606838, |
|
0.16666666666666666, |
|
0.9411764705882353, |
|
0.028469750889679707, |
|
1.0, |
|
0.6537753400826345, |
|
3.9270817815768815e-05, |
|
0.01002929606822616, |
|
0.01377886297525227, |
|
0.015, |
|
0.051372273047149884, |
|
0.16666666666666666, |
|
0.8571428571428572, |
|
], |
|
[ |
|
0.07692307692307693, |
|
0.06521583847234458, |
|
0.056179775280898875, |
|
0.07692307692307693, |
|
0.07692307692307693, |
|
0.16666666666666666, |
|
1.0, |
|
0.007117437722419961, |
|
-1.0, |
|
0.8539203131418077, |
|
1.9758579909666677e-05, |
|
0.002676173590325307, |
|
0.003896139326624358, |
|
0.025, |
|
0.06586910626319493, |
|
0.16666666666666666, |
|
1.0, |
|
], |
|
[ |
|
0.08547008547008549, |
|
0.07477388917423204, |
|
0.06741573033707865, |
|
0.08547008547008549, |
|
0.08547008547008549, |
|
0.33333333333333337, |
|
0.0, |
|
0.6085409252669042, |
|
0.07012195121951223, |
|
0.06017348442747725, |
|
0.023680786070796774, |
|
0.09074155275516495, |
|
0.19638929337502856, |
|
0.020000000000000004, |
|
0.07980295566502463, |
|
0.33333333333333337, |
|
0.0, |
|
], |
|
[ |
|
0.09401709401709403, |
|
0.0792467847873929, |
|
0.06741573033707865, |
|
0.09401709401709403, |
|
0.09401709401709403, |
|
0.33333333333333337, |
|
0.058823529411764705, |
|
0.43060498220640586, |
|
0.1859756097560976, |
|
0.18132746997849566, |
|
0.04243692475803745, |
|
0.23105764525702377, |
|
0.2316847349772711, |
|
0.025, |
|
0.0653764954257565, |
|
0.33333333333333337, |
|
0.14285714285714285, |
|
], |
|
[ |
|
0.10256410256410257, |
|
0.08835244376566789, |
|
0.07865168539325842, |
|
0.10256410256410257, |
|
0.10256410256410257, |
|
0.33333333333333337, |
|
0.7058823529411764, |
|
0.4661921708185055, |
|
0.2774390243902439, |
|
0.10108971416145165, |
|
0.06585161024536003, |
|
0.23366315240945865, |
|
0.4753426385985493, |
|
0.025, |
|
0.05650950035186488, |
|
0.33333333333333337, |
|
0.2857142857142857, |
|
], |
|
[ |
|
0.11111111111111113, |
|
0.09210763521580445, |
|
0.07865168539325842, |
|
0.11111111111111113, |
|
0.11111111111111113, |
|
0.33333333333333337, |
|
0.7647058823529411, |
|
0.35943060498220647, |
|
0.36585365853658536, |
|
0.20575543044917485, |
|
0.05682720021378778, |
|
0.4242464682668294, |
|
0.6025426358703994, |
|
0.025, |
|
0.04299788881069668, |
|
0.33333333333333337, |
|
0.42857142857142855, |
|
], |
|
[ |
|
0.11965811965811968, |
|
0.10193099835710374, |
|
0.0898876404494382, |
|
0.11965811965811968, |
|
0.11965811965811968, |
|
0.33333333333333337, |
|
0.8235294117647058, |
|
0.25266903914590755, |
|
0.4542682926829268, |
|
0.3185927948389591, |
|
0.04438814854864767, |
|
0.07704039807065373, |
|
0.09357213740327856, |
|
0.020000000000000004, |
|
0.04750175932441942, |
|
0.33333333333333337, |
|
0.5714285714285714, |
|
], |
|
[ |
|
0.12820512820512822, |
|
0.10564197106733833, |
|
0.0898876404494382, |
|
0.12820512820512822, |
|
0.12820512820512822, |
|
0.33333333333333337, |
|
0.8823529411764706, |
|
0.21708185053380794, |
|
0.5731707317073171, |
|
0.31247009930654546, |
|
0.05048572289430458, |
|
0.09515439218602051, |
|
0.1216720831812958, |
|
0.035, |
|
0.04334975369458127, |
|
0.33333333333333337, |
|
0.7142857142857142, |
|
], |
|
[ |
|
0.13675213675213677, |
|
0.11716605497409803, |
|
0.10112359550561797, |
|
0.13675213675213677, |
|
0.13675213675213677, |
|
0.33333333333333337, |
|
0.9411764705882353, |
|
0.17081850533807832, |
|
0.75, |
|
0.43848068233986515, |
|
7.610016686353661e-05, |
|
0.04019725595612581, |
|
0.040050948202660634, |
|
0.04, |
|
0.0270935960591133, |
|
0.33333333333333337, |
|
0.8571428571428572, |
|
], |
|
[ |
|
0.1452991452991453, |
|
0.13245553465558702, |
|
0.12359550561797752, |
|
0.1452991452991453, |
|
0.1452991452991453, |
|
0.33333333333333337, |
|
1.0, |
|
0.13879003558718864, |
|
-1.0, |
|
0.573402276077029, |
|
4.1222041606379033e-05, |
|
0.0177390552812359, |
|
0.01416591926721889, |
|
0.025, |
|
0.029978888106966927, |
|
0.33333333333333337, |
|
1.0, |
|
], |
|
[ |
|
0.15384615384615385, |
|
0.12956430935430432, |
|
0.11235955056179775, |
|
0.15384615384615385, |
|
0.15384615384615385, |
|
0.5, |
|
0.0, |
|
0.8220640569395019, |
|
0.036585365853658514, |
|
0.021591320946190845, |
|
0.02102224365609036, |
|
0.08193366760083631, |
|
0.17524613028962724, |
|
0.035, |
|
0.04665728360309641, |
|
0.5, |
|
0.0, |
|
], |
|
[ |
|
0.1623931623931624, |
|
0.13289772205460673, |
|
0.11235955056179775, |
|
0.1623931623931624, |
|
0.1623931623931624, |
|
0.5, |
|
0.058823529411764705, |
|
0.6085409252669042, |
|
0.09146341463414634, |
|
0.10724623674100561, |
|
0.03755886528151192, |
|
0.27910065518972543, |
|
0.2988654305873366, |
|
0.05500000000000001, |
|
0.038916256157635463, |
|
0.5, |
|
0.14285714285714285, |
|
], |
|
[ |
|
0.17094017094017097, |
|
0.14948995384243843, |
|
0.1348314606741573, |
|
0.17094017094017097, |
|
0.17094017094017097, |
|
0.5, |
|
0.1176470588235294, |
|
0.5729537366548044, |
|
0.201219512195122, |
|
0.128910044216783, |
|
0.07292479648632205, |
|
0.4570377290145464, |
|
0.5293941119700996, |
|
0.06, |
|
0.03335679099225897, |
|
0.5, |
|
-1.0, |
|
], |
|
[ |
|
0.17948717948717952, |
|
0.15939155013894887, |
|
0.14606741573033707, |
|
0.17948717948717952, |
|
0.17948717948717952, |
|
0.5, |
|
0.1764705882352941, |
|
0.5373665480427048, |
|
0.25609756097560976, |
|
0.14179331674197213, |
|
0.11072975742939495, |
|
0.48779542320426544, |
|
0.6062938422242609, |
|
0.03, |
|
0.03019000703729768, |
|
0.5, |
|
-1.0, |
|
], |
|
[ |
|
0.18803418803418806, |
|
0.16985098284653036, |
|
0.15730337078651685, |
|
0.18803418803418806, |
|
0.18803418803418806, |
|
0.5, |
|
0.23529411764705882, |
|
0.5017793594306051, |
|
0.2835365853658536, |
|
0.13783555222654456, |
|
0.14902252432012042, |
|
0.5493108115837035, |
|
0.6267549677907782, |
|
0.03, |
|
0.027797325826882473, |
|
0.5, |
|
-1.0, |
|
], |
|
[ |
|
0.1965811965811966, |
|
0.17343610222012087, |
|
0.15730337078651685, |
|
0.1965811965811966, |
|
0.1965811965811966, |
|
0.5, |
|
0.2941176470588235, |
|
0.5017793594306051, |
|
0.29268292682926833, |
|
0.13881653659361634, |
|
0.1743884335980532, |
|
0.5378719996949651, |
|
0.5012600643161381, |
|
0.03, |
|
0.024982406755805767, |
|
0.5, |
|
-1.0, |
|
], |
|
[ |
|
0.20512820512820515, |
|
0.1834431432040899, |
|
0.16853932584269662, |
|
0.20512820512820515, |
|
0.20512820512820515, |
|
0.5, |
|
0.3529411764705882, |
|
0.4661921708185055, |
|
0.25914634146341464, |
|
0.17107304225964678, |
|
0.1814616198390152, |
|
0.38255835382787134, |
|
0.3972493426863412, |
|
0.04, |
|
0.0270935960591133, |
|
0.5, |
|
-1.0, |
|
], |
|
[ |
|
0.2136752136752137, |
|
0.18652825067263504, |
|
0.16853932584269662, |
|
0.2136752136752137, |
|
0.2136752136752137, |
|
0.5, |
|
0.4117647058823529, |
|
0.43060498220640586, |
|
0.3445121951219513, |
|
0.19370816923188441, |
|
0.1919494477135451, |
|
0.4560209457355474, |
|
0.5336568464631241, |
|
0.035, |
|
0.024982406755805767, |
|
0.5, |
|
-1.0, |
|
], |
|
[ |
|
0.22222222222222224, |
|
0.19703190212011848, |
|
0.1797752808988764, |
|
0.22222222222222224, |
|
0.22222222222222224, |
|
0.5, |
|
0.47058823529411764, |
|
0.43060498220640586, |
|
0.3597560975609756, |
|
0.19267402807644915, |
|
0.2160958421223465, |
|
0.4458531129455577, |
|
0.5449104655247086, |
|
0.05500000000000001, |
|
0.023011963406052074, |
|
0.5, |
|
-1.0, |
|
], |
|
[ |
|
0.2307692307692308, |
|
0.19621555615269748, |
|
0.1741573033707865, |
|
0.2307692307692308, |
|
0.2307692307692308, |
|
0.5, |
|
0.5294117647058824, |
|
0.3950177935943062, |
|
0.36890243902439024, |
|
0.18101819411892625, |
|
0.2173153569914779, |
|
0.4351768885160684, |
|
0.5425233342086149, |
|
0.04, |
|
0.024630541871921183, |
|
0.5, |
|
-1.0, |
|
], |
|
[ |
|
0.23931623931623935, |
|
0.21272275190225615, |
|
0.19662921348314605, |
|
0.23931623931623935, |
|
0.23931623931623935, |
|
0.5, |
|
0.5882352941176471, |
|
0.3950177935943062, |
|
0.36585365853658536, |
|
0.1852030830937251, |
|
0.2185348718606093, |
|
0.3415311485202626, |
|
0.4826745419265514, |
|
0.04, |
|
0.02047853624208304, |
|
0.5, |
|
-1.0, |
|
], |
|
[ |
|
0.2478632478632479, |
|
0.2189609956699649, |
|
0.19662921348314605, |
|
0.2478632478632479, |
|
0.2478632478632479, |
|
0.5, |
|
0.6470588235294117, |
|
0.35943060498220647, |
|
0.28963414634146345, |
|
0.2657984391233962, |
|
0.17390062765040062, |
|
0.17252397384325016, |
|
0.20048151848833204, |
|
0.06, |
|
0.020689655172413793, |
|
0.5, |
|
-1.0, |
|
], |
|
[ |
|
0.2564102564102564, |
|
0.23373345623875397, |
|
0.2191011235955056, |
|
0.2564102564102564, |
|
0.2564102564102564, |
|
0.5, |
|
0.7058823529411764, |
|
0.4661921708185055, |
|
0.33841463414634154, |
|
0.10174209292773093, |
|
0.14414446484359486, |
|
0.0733952300154424, |
|
0.4216321839864411, |
|
0.05500000000000001, |
|
0.019493314567206193, |
|
0.5, |
|
0.2857142857142857, |
|
], |
|
[ |
|
0.26495726495726496, |
|
0.24365546118444995, |
|
0.2303370786516854, |
|
0.26495726495726496, |
|
0.26495726495726496, |
|
0.5, |
|
0.7647058823529411, |
|
0.35943060498220647, |
|
0.39939024390243894, |
|
0.19356319617271125, |
|
0.12975418938784455, |
|
0.30434230009087504, |
|
0.5288825838309366, |
|
0.07, |
|
0.015904292751583393, |
|
0.5, |
|
0.42857142857142855, |
|
], |
|
[ |
|
0.27350427350427353, |
|
0.2514175507580112, |
|
0.23595505617977527, |
|
0.27350427350427353, |
|
0.27350427350427353, |
|
0.5, |
|
0.8235294117647058, |
|
0.2882562277580072, |
|
0.451219512195122, |
|
0.28485756396936235, |
|
0.1409737261838533, |
|
0.2735083471552311, |
|
0.15052227023008535, |
|
0.05500000000000001, |
|
0.016537649542575653, |
|
0.5, |
|
0.5714285714285714, |
|
], |
|
[ |
|
0.28205128205128205, |
|
0.2651525716598694, |
|
0.25280898876404495, |
|
0.28205128205128205, |
|
0.28205128205128205, |
|
0.5, |
|
0.8823529411764706, |
|
0.25266903914590755, |
|
0.5640243902439024, |
|
0.2831082223886727, |
|
0.11731513772270441, |
|
0.12200763858438347, |
|
0.1626284361902748, |
|
0.085, |
|
0.01597466572836031, |
|
0.5, |
|
0.7142857142857142, |
|
], |
|
[ |
|
0.2905982905982906, |
|
0.2683635324650587, |
|
0.25280898876404495, |
|
0.2905982905982906, |
|
0.2905982905982906, |
|
0.5, |
|
0.9411764705882353, |
|
0.21708185053380794, |
|
0.6890243902439025, |
|
0.38272404378186387, |
|
0.07609553514606364, |
|
0.06402557209946683, |
|
0.055889564484942325, |
|
0.08, |
|
0.026741731175228708, |
|
0.5, |
|
0.8571428571428572, |
|
], |
|
[ |
|
0.29914529914529914, |
|
0.2816087457864643, |
|
0.2696629213483146, |
|
0.29914529914529914, |
|
0.29914529914529914, |
|
0.5, |
|
1.0, |
|
0.18149466192170824, |
|
-1.0, |
|
0.48835141469543564, |
|
8.878312150250298e-05, |
|
0.025865695638635226, |
|
0.019729640327514418, |
|
0.1, |
|
0.010837438423645322, |
|
0.5, |
|
1.0, |
|
], |
|
[ |
|
0.3076923076923077, |
|
0.28728915314310205, |
|
0.2696629213483146, |
|
0.3076923076923077, |
|
0.3076923076923077, |
|
0.6666666666666666, |
|
0.0, |
|
0.8932384341637012, |
|
0.036585365853658514, |
|
0.013685456785947292, |
|
0.03731496230768564, |
|
0.07590668471456988, |
|
0.16313996432943775, |
|
0.085, |
|
0.01893033075299085, |
|
0.6666666666666666, |
|
0.0, |
|
], |
|
[ |
|
0.3162393162393162, |
|
0.2946090553176436, |
|
0.2808988764044944, |
|
0.3162393162393162, |
|
0.3162393162393162, |
|
0.6666666666666666, |
|
0.058823529411764705, |
|
0.7153024911032031, |
|
0.07621951219512194, |
|
0.08703215985695992, |
|
0.06438819240240236, |
|
0.26130694780724334, |
|
0.2814734738557968, |
|
0.075, |
|
0.014567206192821956, |
|
0.6666666666666666, |
|
0.14285714285714285, |
|
], |
|
[ |
|
0.3247863247863248, |
|
0.29898330912640775, |
|
0.2808988764044944, |
|
0.3247863247863248, |
|
0.3247863247863248, |
|
0.6666666666666666, |
|
0.1176470588235294, |
|
0.6441281138790037, |
|
0.15853658536585366, |
|
0.11227680189431463, |
|
0.109022436612611, |
|
0.4537331833577997, |
|
0.6146488018305888, |
|
0.09, |
|
0.014356087262491202, |
|
0.6666666666666666, |
|
-1.0, |
|
], |
|
[ |
|
0.3333333333333333, |
|
0.3068678505950822, |
|
0.28651685393258425, |
|
0.3333333333333333, |
|
0.3333333333333333, |
|
0.6666666666666666, |
|
0.1764705882352941, |
|
0.6085409252669042, |
|
0.19207317073170735, |
|
0.1324087273781622, |
|
0.15877864327317145, |
|
0.5366010205962164, |
|
0.7976053662711987, |
|
0.085, |
|
0.012948627726952853, |
|
0.6666666666666666, |
|
-1.0, |
|
], |
|
[ |
|
0.3418803418803419, |
|
0.312589075250091, |
|
0.29213483146067415, |
|
0.3418803418803419, |
|
0.3418803418803419, |
|
0.6666666666666666, |
|
0.23529411764705882, |
|
0.5729537366548044, |
|
0.27439024390243905, |
|
0.13844927151037764, |
|
0.2090226558813845, |
|
0.6931856455620589, |
|
0.8547260084777264, |
|
0.105, |
|
0.012033779028852921, |
|
0.6666666666666666, |
|
-1.0, |
|
], |
|
[ |
|
0.35042735042735046, |
|
0.3229770776855231, |
|
0.3033707865168539, |
|
0.35042735042735046, |
|
0.35042735042735046, |
|
0.6666666666666666, |
|
0.2941176470588235, |
|
0.5373665480427048, |
|
0.44512195121951226, |
|
0.15456544325512842, |
|
0.24877884061506758, |
|
0.7310608227047707, |
|
0.8368225236070237, |
|
0.085, |
|
0.011048557353976075, |
|
0.6666666666666666, |
|
-1.0, |
|
], |
|
[ |
|
0.358974358974359, |
|
0.3299160184086015, |
|
0.3089887640449438, |
|
0.358974358974359, |
|
0.358974358974359, |
|
0.6666666666666666, |
|
0.3529411764705882, |
|
0.5373665480427048, |
|
0.36585365853658536, |
|
0.16363109188875738, |
|
0.28048622721248356, |
|
0.6250611658691273, |
|
0.8774037559806166, |
|
0.1, |
|
-1.0, |
|
0.6666666666666666, |
|
-1.0, |
|
], |
|
[ |
|
0.36752136752136755, |
|
0.3403584439085284, |
|
0.3202247191011236, |
|
0.36752136752136755, |
|
0.36752136752136755, |
|
0.6666666666666666, |
|
0.4117647058823529, |
|
0.5017793594306051, |
|
0.4573170731707318, |
|
0.16752120230990405, |
|
0.30243749485684845, |
|
0.6377709568566146, |
|
0.7534434369234653, |
|
0.065, |
|
0.010133708655876143, |
|
0.6666666666666666, |
|
-1.0, |
|
], |
|
[ |
|
0.37606837606837606, |
|
0.346603490559299, |
|
0.3258426966292135, |
|
0.37606837606837606, |
|
0.37606837606837606, |
|
0.6666666666666666, |
|
0.47058823529411764, |
|
0.4661921708185055, |
|
0.4817073170731707, |
|
0.17227631865078405, |
|
0.30243749485684845, |
|
0.5655793440476872, |
|
0.67586166915042, |
|
0.085, |
|
0.01048557353976073, |
|
0.6666666666666666, |
|
-1.0, |
|
], |
|
[ |
|
0.38461538461538464, |
|
0.35855615609895475, |
|
0.33707865168539325, |
|
0.38461538461538464, |
|
0.38461538461538464, |
|
0.6666666666666666, |
|
0.5294117647058824, |
|
0.4661921708185055, |
|
0.4573170731707318, |
|
0.21470510063546525, |
|
0.2926813759037974, |
|
0.4603422746712931, |
|
0.5510488031946638, |
|
0.09, |
|
0.01055594651653765, |
|
0.6666666666666666, |
|
-1.0, |
|
], |
|
[ |
|
0.39316239316239315, |
|
0.363481443435728, |
|
0.34269662921348315, |
|
0.39316239316239315, |
|
0.39316239316239315, |
|
0.6666666666666666, |
|
0.5882352941176471, |
|
0.4661921708185055, |
|
0.375, |
|
0.17794476526445505, |
|
0.25609592982985585, |
|
0.31011254519919423, |
|
0.41447079003816, |
|
0.12000000000000001, |
|
0.00992258972554539, |
|
0.6666666666666666, |
|
-1.0, |
|
], |
|
[ |
|
0.4017094017094017, |
|
0.37893419231070125, |
|
0.3595505617977528, |
|
0.4017094017094017, |
|
0.4017094017094017, |
|
0.6666666666666666, |
|
0.6470588235294117, |
|
0.43060498220640586, |
|
0.301829268292683, |
|
0.24644936815908378, |
|
0.21194949156729978, |
|
0.1474729758069129, |
|
0.17661020532739505, |
|
0.095, |
|
0.00971147079521464, |
|
0.6666666666666666, |
|
-1.0, |
|
], |
|
[ |
|
0.4102564102564103, |
|
0.38712146207562764, |
|
0.3707865168539326, |
|
0.4102564102564103, |
|
0.4102564102564103, |
|
0.6666666666666666, |
|
0.7058823529411764, |
|
0.5373665480427048, |
|
0.3292682926829269, |
|
0.09145383816174166, |
|
0.1782908811792736, |
|
0.10567809912365993, |
|
0.39912494586327196, |
|
0.15500000000000003, |
|
0.009781843771991556, |
|
0.6666666666666666, |
|
0.2857142857142857, |
|
], |
|
[ |
|
0.4188034188034188, |
|
0.40035987251397137, |
|
0.38764044943820225, |
|
0.4188034188034188, |
|
0.4188034188034188, |
|
0.6666666666666666, |
|
0.7647058823529411, |
|
0.43060498220640586, |
|
0.38414634146341464, |
|
0.16671901804914588, |
|
0.17780307523162106, |
|
0.12481904435081566, |
|
0.4894949171153905, |
|
0.125, |
|
0.009429978888106968, |
|
0.6666666666666666, |
|
0.42857142857142855, |
|
], |
|
[ |
|
0.4273504273504274, |
|
0.4107342691832799, |
|
0.398876404494382, |
|
0.4273504273504274, |
|
0.4273504273504274, |
|
0.6666666666666666, |
|
0.8235294117647058, |
|
0.35943060498220647, |
|
0.41158536585365846, |
|
0.22782516249063714, |
|
0.16316889680204447, |
|
0.22620250509980366, |
|
0.3164278966985974, |
|
0.13, |
|
0.007952146375791697, |
|
0.6666666666666666, |
|
0.5714285714285714, |
|
], |
|
[ |
|
0.4358974358974359, |
|
0.43059868772385734, |
|
0.42696629213483145, |
|
0.4358974358974359, |
|
0.4358974358974359, |
|
0.6666666666666666, |
|
0.8823529411764706, |
|
0.32384341637010683, |
|
0.426829268292683, |
|
0.24721289293739582, |
|
0.15194936000603573, |
|
0.1801295127701625, |
|
0.21429277824573129, |
|
0.13, |
|
0.007600281491907108, |
|
0.6666666666666666, |
|
0.7142857142857142, |
|
], |
|
[ |
|
0.4444444444444445, |
|
0.42823128441833647, |
|
0.4157303370786517, |
|
0.4444444444444445, |
|
0.4444444444444445, |
|
0.6666666666666666, |
|
0.9411764705882353, |
|
0.2882562277580072, |
|
0.5975609756097562, |
|
0.31688211274071565, |
|
0.12024197340861972, |
|
0.09468158796128598, |
|
0.07727144070195302, |
|
0.105, |
|
0.008444757213230118, |
|
0.6666666666666666, |
|
0.8571428571428572, |
|
], |
|
[ |
|
0.452991452991453, |
|
0.4431602112975479, |
|
0.43258426966292135, |
|
0.452991452991453, |
|
0.452991452991453, |
|
0.6666666666666666, |
|
1.0, |
|
0.25266903914590755, |
|
-1.0, |
|
0.39799453934810447, |
|
0.0001414661638489788, |
|
0.03743668935364358, |
|
0.027419613352930545, |
|
0.14, |
|
0.00450387051372273, |
|
0.6666666666666666, |
|
1.0, |
|
], |
|
[ |
|
0.46153846153846156, |
|
0.4486433350453922, |
|
0.4382022471910112, |
|
0.46153846153846156, |
|
0.46153846153846156, |
|
0.8333333333333334, |
|
0.0, |
|
1.0000000000000002, |
|
0.0274390243902439, |
|
0.0, |
|
0.045607663417779054, |
|
0.0730876530735452, |
|
0.16024130487418112, |
|
0.095, |
|
0.010415200562983815, |
|
0.8333333333333334, |
|
0.0, |
|
], |
|
[ |
|
0.47008547008547014, |
|
0.463684509495124, |
|
0.4550561797752809, |
|
0.47008547008547014, |
|
0.47008547008547014, |
|
0.8333333333333334, |
|
0.058823529411764705, |
|
0.8220640569395019, |
|
0.05792682926829271, |
|
0.06368183245946799, |
|
0.08755897491589865, |
|
0.25113911501725356, |
|
0.36928580441210074, |
|
0.11, |
|
0.007741027445460941, |
|
0.8333333333333334, |
|
0.14285714285714285, |
|
], |
|
[ |
|
0.47863247863247865, |
|
0.46905198423091704, |
|
0.4606741573033708, |
|
0.47863247863247865, |
|
0.47863247863247865, |
|
0.8333333333333334, |
|
0.1176470588235294, |
|
0.7864768683274024, |
|
0.12195121951219515, |
|
0.08132988619614859, |
|
0.14999813621542551, |
|
0.2996905165894547, |
|
0.6364740024348741, |
|
0.08, |
|
0.007107670654468685, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.4871794871794872, |
|
0.47317112992486215, |
|
0.4606741573033708, |
|
0.4871794871794872, |
|
0.4871794871794872, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7864768683274024, |
|
0.1280487804878049, |
|
0.07948389590934354, |
|
0.16512012059265466, |
|
0.2686786265799859, |
|
0.6328933054607335, |
|
0.08, |
|
0.006896551724137932, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.4957264957264958, |
|
0.47586507161735137, |
|
0.4606741573033708, |
|
0.4957264957264958, |
|
0.4957264957264958, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7864768683274024, |
|
0.13109756097560973, |
|
0.07630898591345106, |
|
0.16512012059265466, |
|
0.3024866706067019, |
|
0.6460225276992488, |
|
0.06, |
|
0.006966924700914849, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5042735042735044, |
|
0.48720547768144135, |
|
0.47191011235955055, |
|
0.5042735042735044, |
|
0.5042735042735044, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7508896797153027, |
|
0.13414634146341461, |
|
0.07882185227245272, |
|
0.1709737919644853, |
|
0.3240933152854302, |
|
0.5699753443436925, |
|
0.065, |
|
0.006755805770584096, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5128205128205129, |
|
0.4897837703618793, |
|
0.47191011235955055, |
|
0.5128205128205129, |
|
0.5128205128205129, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7508896797153027, |
|
0.13109756097560973, |
|
0.08157634039674291, |
|
0.17707136631014223, |
|
0.3024866706067019, |
|
0.55735765024434, |
|
0.05500000000000001, |
|
-1.0, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5213675213675214, |
|
0.5080154969676149, |
|
0.4943820224719101, |
|
0.5213675213675214, |
|
0.5213675213675214, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7508896797153027, |
|
0.14329268292682926, |
|
0.0845579529804045, |
|
0.18341284362962543, |
|
0.33832828119141584, |
|
0.35172333830083996, |
|
0.07, |
|
0.007248416608022519, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.52991452991453, |
|
0.5134714091832119, |
|
0.5, |
|
0.52991452991453, |
|
0.52991452991453, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7508896797153027, |
|
0.1524390243902439, |
|
0.08584821320704569, |
|
0.12780296559723434, |
|
0.27477932625397977, |
|
0.30653835267478063, |
|
0.09, |
|
0.0061928219563687536, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5384615384615385, |
|
0.5314514291156592, |
|
0.5224719101123595, |
|
0.5384615384615385, |
|
0.5384615384615385, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7153024911032031, |
|
0.1524390243902439, |
|
0.10902940536883562, |
|
0.19268115663502394, |
|
0.3993352779313545, |
|
0.6039067109081672, |
|
0.07, |
|
0.009992962702322309, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5470085470085471, |
|
0.5371488436799516, |
|
0.5280898876404494, |
|
0.5470085470085471, |
|
0.5470085470085471, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7153024911032031, |
|
0.1524390243902439, |
|
0.09519414308840943, |
|
0.20072995477129107, |
|
0.41077408982009295, |
|
0.596574807580165, |
|
0.105, |
|
0.0061928219563687536, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5555555555555557, |
|
0.5493089971529934, |
|
0.5449438202247191, |
|
0.5555555555555557, |
|
0.5555555555555557, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7153024911032031, |
|
0.15853658536585366, |
|
0.09882330200304443, |
|
0.20853484993373195, |
|
0.42348388080758015, |
|
0.48352708882515627, |
|
0.09, |
|
0.005348346235045743, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5641025641025642, |
|
0.557574500073131, |
|
0.550561797752809, |
|
0.5641025641025642, |
|
0.5641025641025642, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7153024911032031, |
|
0.16158536585365854, |
|
0.10281489356561238, |
|
0.21463242427938886, |
|
0.43949821745181405, |
|
0.509615023922466, |
|
0.13, |
|
0.004996481351161154, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5726495726495727, |
|
0.5654964573986455, |
|
0.5561797752808989, |
|
0.5726495726495727, |
|
0.5726495726495727, |
|
0.8333333333333334, |
|
-1.0, |
|
0.7153024911032031, |
|
0.16463414634146342, |
|
0.10698045279918819, |
|
0.22121780457269832, |
|
0.45271640007880076, |
|
0.596574807580165, |
|
0.065, |
|
0.005207600281491907, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5811965811965812, |
|
0.5711938719629379, |
|
0.5617977528089888, |
|
0.5811965811965812, |
|
0.5811965811965812, |
|
0.8333333333333334, |
|
-1.0, |
|
0.6797153024911033, |
|
0.1676829268292683, |
|
0.11068209824340977, |
|
0.22731537891835524, |
|
0.4585629039330449, |
|
0.37832280153731257, |
|
0.075, |
|
0.004644616467276566, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5897435897435899, |
|
0.5852078110703316, |
|
0.5786516853932584, |
|
0.5897435897435899, |
|
0.5897435897435899, |
|
0.8333333333333334, |
|
-1.0, |
|
0.6797153024911033, |
|
0.12195121951219515, |
|
0.11405997052214464, |
|
0.1699981800691802, |
|
0.2752877178934793, |
|
0.24975872922769482, |
|
0.065, |
|
0.004292751583391977, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.5982905982905984, |
|
0.5917147687189832, |
|
0.5842696629213483, |
|
0.5982905982905984, |
|
0.5982905982905984, |
|
0.8333333333333334, |
|
-1.0, |
|
0.6441281138790037, |
|
0.17378048780487806, |
|
0.07403290888443234, |
|
0.2399983335573216, |
|
0.4885580106635147, |
|
0.6259024208921734, |
|
0.095, |
|
0.00422237860661506, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.6068376068376069, |
|
0.6036980472324172, |
|
0.5955056179775281, |
|
0.6068376068376069, |
|
0.6068376068376069, |
|
0.8333333333333334, |
|
0.1764705882352941, |
|
0.6085409252669042, |
|
0.1829268292682927, |
|
0.14164834368279897, |
|
0.32438876250121335, |
|
0.6319244530023704, |
|
0.8306841859370685, |
|
0.07, |
|
0.0035186488388458817, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.6153846153846155, |
|
0.6120587905154204, |
|
0.6067415730337078, |
|
0.6153846153846155, |
|
0.6153846153846155, |
|
0.8333333333333334, |
|
0.23529411764705882, |
|
0.5729537366548044, |
|
0.2439024390243902, |
|
0.1766593374731196, |
|
0.40731577360214744, |
|
0.8274010383899237, |
|
0.9764697055985051, |
|
0.08, |
|
0.003237156931738213, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.623931623931624, |
|
0.6218957594228435, |
|
0.6179775280898876, |
|
0.623931623931624, |
|
0.623931623931624, |
|
0.8333333333333334, |
|
0.2941176470588235, |
|
0.5373665480427048, |
|
0.5060975609756098, |
|
0.19185251407446782, |
|
0.4707305467969794, |
|
0.9318755203070687, |
|
0.99300911543144, |
|
0.095, |
|
0.002674173117522871, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.6324786324786326, |
|
0.6299469715265329, |
|
0.6235955056179775, |
|
0.6324786324786326, |
|
0.6324786324786326, |
|
0.8333333333333334, |
|
0.3529411764705882, |
|
0.5373665480427048, |
|
0.36585365853658536, |
|
0.19037862130620728, |
|
0.5121940523474464, |
|
0.8741730692238767, |
|
1.0, |
|
0.09, |
|
0.00302603800140746, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.6410256410256411, |
|
0.6436309708054273, |
|
0.6404494382022472, |
|
0.6410256410256411, |
|
0.6410256410256411, |
|
0.8333333333333334, |
|
0.4117647058823529, |
|
0.5017793594306051, |
|
0.4573170731707318, |
|
0.21960035760021265, |
|
0.5512185281596508, |
|
0.8352811088021659, |
|
0.9004225222429487, |
|
0.08, |
|
0.0025334271639690367, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.6495726495726497, |
|
0.650389635127367, |
|
0.6460674157303371, |
|
0.6495726495726497, |
|
0.6495726495726497, |
|
0.8333333333333334, |
|
0.47058823529411764, |
|
0.5017793594306051, |
|
0.4573170731707318, |
|
0.24515427549713678, |
|
0.5512185281596508, |
|
0.6868307500683152, |
|
0.8008450444858972, |
|
0.11, |
|
0.002603800140745954, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.6581196581196582, |
|
0.6601415679965169, |
|
0.6573033707865168, |
|
0.6581196581196582, |
|
0.6581196581196582, |
|
0.8333333333333334, |
|
0.5294117647058824, |
|
0.4661921708185055, |
|
0.4817073170731707, |
|
0.2447531833667577, |
|
0.5243892010387603, |
|
0.5162653550162368, |
|
0.6980278885141472, |
|
0.14500000000000002, |
|
0.00274454609429979, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.6666666666666667, |
|
0.6665464823992409, |
|
0.6629213483146067, |
|
0.6666666666666667, |
|
0.6666666666666667, |
|
0.8333333333333334, |
|
0.5882352941176471, |
|
0.4661921708185055, |
|
0.5609756097560976, |
|
0.25764612076255833, |
|
0.4707305467969794, |
|
0.33644214820887275, |
|
0.5328042995645191, |
|
0.09, |
|
0.002463054187192118, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.6752136752136753, |
|
0.6788699050657669, |
|
0.6797752808988764, |
|
0.6752136752136753, |
|
0.6752136752136753, |
|
0.8333333333333334, |
|
0.6470588235294117, |
|
0.4661921708185055, |
|
0.39634146341463417, |
|
0.31621523666851903, |
|
0.3292668219777389, |
|
0.05598790027897992, |
|
0.1067013596417939, |
|
0.115, |
|
0.003237156931738213, |
|
0.8333333333333334, |
|
-1.0, |
|
], |
|
[ |
|
0.6837606837606839, |
|
0.6917715727925495, |
|
0.6910112359550562, |
|
0.6837606837606839, |
|
0.6837606837606839, |
|
0.8333333333333334, |
|
0.7058823529411764, |
|
0.5729537366548044, |
|
0.4085365853658537, |
|
0.10700461497571703, |
|
0.2902423461655346, |
|
0.14310589162361226, |
|
0.29698982741040586, |
|
0.125, |
|
0.002463054187192118, |
|
0.8333333333333334, |
|
0.2857142857142857, |
|
], |
|
[ |
|
0.6923076923076924, |
|
0.7013534335851533, |
|
0.7022471910112359, |
|
0.6923076923076924, |
|
0.6923076923076924, |
|
0.8333333333333334, |
|
0.7647058823529411, |
|
0.4661921708185055, |
|
0.4969512195121951, |
|
0.1702370309517481, |
|
0.27560816773595803, |
|
0.14910491296970624, |
|
0.3440504162133959, |
|
0.13, |
|
0.002463054187192118, |
|
0.8333333333333334, |
|
0.42857142857142855, |
|
], |
|
[ |
|
0.7008547008547009, |
|
0.7074079995101924, |
|
0.7078651685393258, |
|
0.7008547008547009, |
|
0.7008547008547009, |
|
0.8333333333333334, |
|
0.8235294117647058, |
|
0.3950177935943062, |
|
0.4024390243902439, |
|
0.1639017082658806, |
|
0.2392666246358428, |
|
0.13484961139814058, |
|
0.31250618096501487, |
|
0.08, |
|
0.0019704433497536944, |
|
0.8333333333333334, |
|
0.5714285714285714, |
|
], |
|
[ |
|
0.7094017094017095, |
|
0.7108774698717316, |
|
0.7078651685393258, |
|
0.7094017094017095, |
|
0.7094017094017095, |
|
0.8333333333333334, |
|
0.8823529411764706, |
|
0.35943060498220647, |
|
0.39634146341463417, |
|
0.21857588131538888, |
|
0.22731537891835524, |
|
0.13039610063612506, |
|
0.20985953437298585, |
|
0.15500000000000003, |
|
-1.0, |
|
0.8333333333333334, |
|
0.7142857142857142, |
|
], |
|
[ |
|
0.7179487179487181, |
|
0.7108774698717316, |
|
0.7022471910112359, |
|
0.7179487179487181, |
|
0.7179487179487181, |
|
0.8333333333333334, |
|
0.9411764705882353, |
|
0.32384341637010683, |
|
0.4573170731707318, |
|
0.26124628506535874, |
|
0.17072988899065902, |
|
0.14259749998411278, |
|
0.10329117204737433, |
|
0.09, |
|
-1.0, |
|
0.8333333333333334, |
|
0.8571428571428572, |
|
], |
|
[ |
|
0.7264957264957266, |
|
0.7516947682427814, |
|
0.7640449438202247, |
|
0.7264957264957266, |
|
0.7264957264957266, |
|
0.8333333333333334, |
|
1.0, |
|
0.2882562277580072, |
|
-1.0, |
|
0.3312441104694711, |
|
0.00023512490579826907, |
|
0.047782459217458176, |
|
0.03530908235262022, |
|
0.085, |
|
0.0, |
|
0.8333333333333334, |
|
1.0, |
|
], |
|
[ |
|
0.7350427350427351, |
|
0.7550962097737021, |
|
0.7640449438202247, |
|
0.7350427350427351, |
|
0.7350427350427351, |
|
0.9999999999999999, |
|
0.0, |
|
-1.0, |
|
0.0, |
|
0.00864039432672098, |
|
0.045607663417779054, |
|
0.0726936495529331, |
|
0.161264361152507, |
|
0.09, |
|
-1.0, |
|
0.9999999999999999, |
|
0.0, |
|
], |
|
[ |
|
0.7435897435897437, |
|
0.7653005343664645, |
|
0.7752808988764045, |
|
0.7435897435897437, |
|
0.7435897435897437, |
|
0.9999999999999999, |
|
0.058823529411764705, |
|
-1.0, |
|
0.06097560975609759, |
|
0.06690506680841815, |
|
0.1341444429167175, |
|
0.243767436244511, |
|
0.34200430365674417, |
|
0.06, |
|
-1.0, |
|
0.9999999999999999, |
|
0.14285714285714285, |
|
], |
|
[ |
|
0.7521367521367522, |
|
0.7687019758973853, |
|
0.7752808988764045, |
|
0.7521367521367522, |
|
0.7521367521367522, |
|
0.9999999999999999, |
|
0.1176470588235294, |
|
-1.0, |
|
0.12195121951219515, |
|
0.061666706936960886, |
|
0.24633981087680482, |
|
0.3327359731569215, |
|
0.5911185074290938, |
|
0.04, |
|
0.0018296973961998584, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.7606837606837608, |
|
0.7858384383301644, |
|
0.797752808988764, |
|
0.7606837606837608, |
|
0.7606837606837608, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.1829268292682927, |
|
0.11659699905767515, |
|
0.28536428668900904, |
|
0.5119440260804912, |
|
0.8622284211854495, |
|
0.045, |
|
0.001337086558761435, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.7692307692307694, |
|
0.7824301939161817, |
|
0.7865168539325842, |
|
0.7692307692307694, |
|
0.7692307692307694, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.2439024390243902, |
|
0.09646024113852172, |
|
0.3756083870047315, |
|
0.4725436740192808, |
|
0.7324707832177849, |
|
0.05500000000000001, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.7777777777777779, |
|
0.8062164745419109, |
|
0.8202247191011236, |
|
0.7777777777777779, |
|
0.7777777777777779, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.2073170731707317, |
|
0.11115567690337544, |
|
0.4634134575821911, |
|
0.3535800303764005, |
|
0.7502037587087667, |
|
0.06, |
|
0.00154820548909219, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.7863247863247864, |
|
0.8027163912065933, |
|
0.8089887640449438, |
|
0.7863247863247864, |
|
0.7863247863247864, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.201219512195122, |
|
0.11461570058230844, |
|
0.49999890365613264, |
|
0.22851568705952632, |
|
0.7278670299653185, |
|
0.75, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.7948717948717949, |
|
0.826526481923039, |
|
0.8426966292134831, |
|
0.7948717948717949, |
|
0.7948717948717949, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.17682926829268295, |
|
0.10304201802498372, |
|
0.4829256954882932, |
|
0.22851568705952632, |
|
0.5962337888207231, |
|
0.8, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8034188034188036, |
|
0.8231250403921182, |
|
0.8314606741573034, |
|
0.8034188034188036, |
|
0.8034188034188036, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.1829268292682927, |
|
0.10050982192475896, |
|
0.3341448814542644, |
|
0.3185010072509358, |
|
0.4903474640139954, |
|
0.65, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8119658119658121, |
|
0.8367308065158015, |
|
0.848314606741573, |
|
0.8119658119658121, |
|
0.8119658119658121, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.1829268292682927, |
|
0.10136516297388068, |
|
0.3292668219777389, |
|
0.3370573020926671, |
|
0.5761136820136477, |
|
0.65, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8205128205128206, |
|
0.8367308065158015, |
|
0.8426966292134831, |
|
0.8205128205128206, |
|
0.8205128205128206, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.1829268292682927, |
|
0.11133930944499479, |
|
0.3609742085751549, |
|
0.31646744069293786, |
|
0.16689117068329928, |
|
0.4, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8290598290598292, |
|
0.8503365726394846, |
|
0.8595505617977528, |
|
0.8290598290598292, |
|
0.8290598290598292, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.1829268292682927, |
|
0.11538889023123203, |
|
0.3682912977899432, |
|
0.48576185664626753, |
|
0.1992879528302852, |
|
0.6, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8376068376068377, |
|
0.8537380141704054, |
|
0.8595505617977528, |
|
0.8376068376068377, |
|
0.8376068376068377, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.1829268292682927, |
|
0.12207214825911519, |
|
0.3292668219777389, |
|
0.28443876740447005, |
|
-1.0, |
|
0.6, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8461538461538463, |
|
0.8707452218250095, |
|
0.8820224719101123, |
|
0.8461538461538463, |
|
0.8461538461538463, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.1829268292682927, |
|
0.12593809650373308, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.5, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8547008547008548, |
|
0.8741466633559303, |
|
0.8820224719101123, |
|
0.8547008547008548, |
|
0.8547008547008548, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.1829268292682927, |
|
0.12980404474835092, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.15000000000000002, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8632478632478634, |
|
0.8775481048868511, |
|
0.8820224719101123, |
|
0.8632478632478634, |
|
0.8632478632478634, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
0.1829268292682927, |
|
0.1331867494623916, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.35, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8717948717948719, |
|
0.8877524294796135, |
|
0.8932584269662921, |
|
0.8717948717948719, |
|
0.8717948717948719, |
|
0.9999999999999999, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
1.0000000000000002, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8803418803418804, |
|
0.8843509879486927, |
|
0.8820224719101123, |
|
0.8803418803418804, |
|
0.8803418803418804, |
|
0.9999999999999999, |
|
0.1764705882352941, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.4414621899378262, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8888888888888891, |
|
0.8877524294796135, |
|
0.8820224719101123, |
|
0.8888888888888891, |
|
0.8888888888888891, |
|
0.9999999999999999, |
|
0.23529411764705882, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9512194052347446, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.8974358974358976, |
|
0.9013581956032967, |
|
0.898876404494382, |
|
0.8974358974358976, |
|
0.8974358974358976, |
|
0.9999999999999999, |
|
0.2941176470588235, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.8536582157042338, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.9059829059829061, |
|
0.894555312541455, |
|
0.8820224719101123, |
|
0.9059829059829061, |
|
0.9059829059829061, |
|
0.9999999999999999, |
|
0.3529411764705882, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9024388104694893, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.9145299145299146, |
|
0.9047596371342175, |
|
0.8932584269662921, |
|
0.9145299145299146, |
|
0.9145299145299146, |
|
0.9999999999999999, |
|
0.4117647058823529, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.9230769230769232, |
|
0.9081610786651383, |
|
0.8932584269662921, |
|
0.9230769230769232, |
|
0.9230769230769232, |
|
0.9999999999999999, |
|
0.47058823529411764, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.8536582157042338, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.9316239316239318, |
|
0.9183654032579007, |
|
0.9044943820224719, |
|
0.9316239316239318, |
|
0.9316239316239318, |
|
0.9999999999999999, |
|
0.5294117647058824, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.9401709401709403, |
|
0.9217668447888215, |
|
0.9044943820224719, |
|
0.9401709401709403, |
|
0.9401709401709403, |
|
0.9999999999999999, |
|
0.5882352941176471, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.9487179487179489, |
|
0.965985584690792, |
|
0.9719101123595505, |
|
0.9487179487179489, |
|
0.9487179487179489, |
|
0.9999999999999999, |
|
0.6470588235294117, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
-1.0, |
|
], |
|
[ |
|
0.9572649572649574, |
|
0.9625841431598712, |
|
0.9606741573033708, |
|
0.9572649572649574, |
|
0.9572649572649574, |
|
0.9999999999999999, |
|
0.7058823529411764, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
0.2857142857142857, |
|
], |
|
[ |
|
0.9658119658119659, |
|
0.9795913508144752, |
|
0.9831460674157303, |
|
0.9658119658119659, |
|
0.9658119658119659, |
|
0.9999999999999999, |
|
0.7647058823529411, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
0.42857142857142855, |
|
], |
|
[ |
|
0.9743589743589745, |
|
0.9761899092835544, |
|
0.9719101123595505, |
|
0.9743589743589745, |
|
0.9743589743589745, |
|
0.9999999999999999, |
|
0.8235294117647058, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
0.5714285714285714, |
|
], |
|
[ |
|
0.9829059829059831, |
|
0.9897956754072376, |
|
0.9887640449438202, |
|
0.9829059829059831, |
|
0.9829059829059831, |
|
0.9999999999999999, |
|
0.8823529411764706, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
0.7142857142857142, |
|
], |
|
[ |
|
0.9914529914529915, |
|
1.0, |
|
1.0, |
|
0.9914529914529915, |
|
0.9914529914529915, |
|
0.9999999999999999, |
|
0.9411764705882353, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
0.8571428571428572, |
|
], |
|
[ |
|
1.0000000000000002, |
|
0.9965985584690792, |
|
0.9887640449438202, |
|
1.0000000000000002, |
|
1.0000000000000002, |
|
0.9999999999999999, |
|
1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
-1.0, |
|
0.9999999999999999, |
|
1.0, |
|
], |
|
] |
|
|
|
|
|
@torch.jit.script
def gaussian(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Evaluate the normal density with parameters ``mean`` and ``std`` at ``x``.

    The three tensors are combined element-wise, so they must be broadcastable
    against each other. ``std`` is used as a divisor and is not checked here.
    """
    # sqrt(2 * pi); pi is deliberately kept at the truncated value 3.14159.
    norm_const = (2 * 3.14159) ** 0.5
    z = (x - mean) / std
    density: torch.Tensor = torch.exp(-0.5 * (z**2)) / (norm_const * std)
    return density
|
|
|
|
|
class GaussianLayer(nn.Module):
    """Gaussian pairwise positional embedding layer.

    Each scalar input is first transformed with a learned, per-edge-type affine
    map (``mul * x + bias``) and then evaluated under ``k`` learned Gaussians,
    producing a ``k``-dimensional feature per input scalar.

    Args:
        k: Number of Gaussian kernels (size of the output feature axis).
        edge_types: Number of rows in the per-edge-type ``mul``/``bias`` tables.
    """

    def __init__(self, k: int = 128, edge_types: int = 1024):
        super().__init__()
        self.k = k
        # Single-row embeddings are used as plain learnable (1, k) parameters.
        self.means = nn.Embedding(1, k)
        self.stds = nn.Embedding(1, k)
        # One scalar scale and one scalar shift per edge type.
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        nn.init.uniform_(self.means.weight, 0, 3)
        nn.init.uniform_(self.stds.weight, 0, 3)
        # Start from the identity affine map: mul = 1, bias = 0.
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x: torch.Tensor, edge_types: torch.Tensor) -> torch.Tensor:
        """Forward pass to compute the Gaussian pos. embeddings.

        Args:
            x: 3-D tensor of scalars (the ``expand`` below requires exactly four
                dimensions after a trailing axis is appended).
            edge_types: Integer index tensor with the same shape as ``x``; it is
                looked up in the ``mul`` and ``bias`` embedding tables, so it
                must be a tensor of indices rather than a Python ``int``.

        Returns:
            Tensor of shape ``x.shape + (k,)`` in the dtype of ``means.weight``.
        """
        mul = self.mul(edge_types)
        bias = self.bias(edge_types)
        x = mul * x.unsqueeze(-1) + bias
        x = x.expand(-1, -1, -1, self.k)
        mean = self.means.weight.float().view(-1)
        # abs() + epsilon keeps the standard deviation strictly positive.
        std = self.stds.weight.float().view(-1).abs() + 1e-5
        # The density is evaluated in float32 and cast back afterwards.
        output: torch.Tensor = gaussian(x.float(), mean, std).type_as(self.means.weight)
        return output
|
|
|
|
|
class ParallelBlock(nn.Module):
    """Parallel transformer block (MLP & Attention in parallel).

    Based on:
      "Scaling Vision Transformers to 22 Billion Parameters" -
      https://arxiv.org/abs/2302.05442

    Adapted from TIMM implementation.

    A single fused input projection produces the MLP hidden activations and the
    attention q/k/v; a single fused output projection maps the concatenated MLP
    and attention results back to two ``dim``-sized tensors that are summed with
    the residual. The attention logits are biased by a linear projection of the
    pairwise positional embedding.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        dropout: Dropout probability for ``proj_drop`` and ``attn_drop``.
        k: Size of the last axis of the pairwise positional embedding.
        gradient_checkpointing: Stored on the module; not used inside the block.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: int = 4,
        dropout: float = 0.0,
        k: int = 128,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.mlp_hidden_dim = int(mlp_ratio * dim)
        self.proj_drop = nn.Dropout(dropout)
        # NOTE: attn_drop is instantiated but is not applied in forward().
        self.attn_drop = nn.Dropout(dropout)
        self.gradient_checkpointing = gradient_checkpointing

        self.in_proj_in_dim = dim
        self.in_proj_out_dim = self.mlp_hidden_dim + 3 * dim
        self.out_proj_in_dim = self.mlp_hidden_dim + dim
        self.out_proj_out_dim = 2 * dim

        # Split sizes for the fused projections: [mlp, q, k, v] and [mlp, attn].
        self.in_split = [self.mlp_hidden_dim] + [dim] * 3
        self.out_split = [dim] * 2

        self.in_norm = nn.LayerNorm(dim)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.in_proj = nn.Linear(self.in_proj_in_dim, self.in_proj_out_dim, bias=False)
        self.act_fn = nn.GELU()
        self.out_proj = nn.Linear(
            self.out_proj_in_dim, self.out_proj_out_dim, bias=False
        )
        self.gaussian_proj = nn.Linear(k, 1)
        self.pos_embed_ff_norm = nn.LayerNorm(k)

    def forward(
        self,
        x: torch.Tensor,
        pos_embed: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for the parallel block.

        Args:
            x: Token features of shape ``(batch, tokens, dim)``.
            pos_embed: Pairwise positional embedding whose last axis has size
                ``k``; it is permuted with ``(0, 3, 1, 2)``, so it must be 4-D.
            attention_mask: Optional additive mask that is summed with the
                positional bias before attention. When ``None`` only the
                positional bias is used.

        Returns:
            ``(x, pos_embed)``: the updated token features and the positional
            embedding, which is passed through unchanged.
        """
        b, n, c = x.shape
        res = x

        # Fused input projection -> MLP hidden activations plus q, k, v.
        x = self.in_proj(self.in_norm(x))
        x, q, k, v = torch.split(x, self.in_split, dim=-1)
        x = self.act_fn(x)
        x = self.proj_drop(x)

        # Reshape to (batch, heads, tokens, head_dim); q and k are normalised
        # per head.
        q = self.q_norm(q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
        k = self.k_norm(k.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
        v = v.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        # Project the k-dim pairwise features to one bias channel and move that
        # channel to the head axis so it broadcasts over all heads.
        attn_bias = self.gaussian_proj(self.pos_embed_ff_norm(pos_embed)).permute(
            0, 3, 1, 2
        )
        # The mask is optional: adding None to a tensor would raise TypeError.
        if attention_mask is not None:
            attn_bias = attention_mask + attn_bias

        x_attn = (
            f.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_bias,
                is_causal=False,
            )
            .transpose(1, 2)
            .reshape(b, n, c)
        )

        # Fused output projection for both branches, then the residual sum.
        x_mlp, x_attn = self.out_proj(torch.cat([x, x_attn], dim=-1)).split(
            self.out_split, dim=-1
        )

        x = x_mlp + x_attn + res

        return x, pos_embed
|
|
|
|
|
class AtomformerEncoder(nn.Module):
    """Atomformer encoder.

    Embeds atom tokens (learned embedding plus a projection of a frozen
    per-token metadata table), builds Gaussian pairwise distance embeddings
    from the coordinates and runs both through a stack of ``ParallelBlock``
    layers.
    """

    def __init__(self, config: AtomformerConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.dim = config.dim
        self.num_heads = config.num_heads
        self.depth = config.depth
        self.mlp_ratio = config.mlp_ratio
        self.dropout = config.dropout
        self.k = config.k
        self.gradient_checkpointing = config.gradient_checkpointing

        # Frozen lookup table of 17 metadata features per token. Every row
        # starts at -1; rows 1:-4 are then overwritten with ATOM_METADATA.
        metadata_table = nn.Embedding(self.vocab_size, 17)
        metadata_table.weight.requires_grad = False
        metadata_table.weight.fill_(-1)
        metadata_table.weight[1:-4] = torch.tensor(ATOM_METADATA, dtype=torch.float32)
        self.metadata_vocab = metadata_table
        self.embed_metadata = nn.Linear(17, self.dim)

        # One affine (mul, bias) pair per ordered pair of token ids.
        n_edge_types = (self.vocab_size + 1) ** 2
        self.gaussian_embed = GaussianLayer(k=self.k, edge_types=n_edge_types)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.dim)
        nn.init.normal_(self.embed_tokens.weight, std=0.02)

        self.blocks = nn.ModuleList(
            ParallelBlock(
                self.dim,
                self.num_heads,
                self.mlp_ratio,
                self.dropout,
                self.k,
                self.gradient_checkpointing,
            )
            for _ in range(self.depth)
        )

    def _expand_mask(
        self,
        mask: torch.Tensor,
        dtype: torch.dtype,
        device: torch.device,
        tgt_len: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Turn a padding mask into an additive attention mask.

        Expands ``mask`` from `[bsz, seq_len]` to
        `[bsz, 1, tgt_seq_len, src_seq_len]`; positions where ``mask`` is 1
        become 0 and all other positions become the most negative finite value
        of ``dtype``.
        """
        batch, src_len = mask.size()
        if tgt_len is None:
            tgt_len = src_len

        keep = mask[:, None, None, :].expand(batch, 1, tgt_len, src_len).to(dtype)
        blocked: torch.Tensor = 1.0 - keep

        lowest = torch.finfo(dtype).min
        additive = blocked.masked_fill(blocked.to(torch.bool), lowest)
        return additive.to(device)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode atoms and their coordinates.

        An extra token (id 122) is prepended to every sequence, positioned at
        the mean of the given coordinates, so both outputs have one more token
        than ``input_ids``.

        Returns:
            ``(input_embeds, pos_embeds)``: the token features after the last
            block and the pairwise positional embeddings.
        """
        # Prepend the coordinate mean as the position of the extra token.
        n_points = coords.shape[1]
        centre = torch.sum(coords, dim=1, keepdim=True) / n_points
        coords = torch.cat([centre, coords], dim=1)

        # Pairwise Euclidean distances between all tokens.
        r_ij = torch.cdist(coords, coords, p=2)

        extra_token = torch.zeros(
            input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
        ).fill_(122)
        input_ids = torch.cat([extra_token, input_ids], dim=1)

        # Encode each ordered token pair (i, j) as a single edge-type index.
        row_ids = input_ids.unsqueeze(-1)
        col_ids = input_ids.unsqueeze(-2)
        edge_type = row_ids * self.vocab_size + col_ids
        pos_embeds = self.gaussian_embed(r_ij, edge_type)

        token_embeds = self.embed_tokens(input_ids)
        atom_metadata = self.metadata_vocab(input_ids)
        input_embeds = token_embeds + self.embed_metadata(atom_metadata)

        if attention_mask is not None:
            # The prepended token is always attendable.
            extra_mask = torch.ones(
                attention_mask.size(0),
                1,
                dtype=torch.bool,
                device=attention_mask.device,
            )
            attention_mask = torch.cat([extra_mask, attention_mask.bool()], dim=1)
            attention_mask = self._expand_mask(
                attention_mask, input_embeds.dtype, input_embeds.device
            )

        for block in self.blocks:
            input_embeds, pos_embeds = block(input_embeds, pos_embeds, attention_mask)

        return input_embeds, pos_embeds
|
|
|
|
|
class AtomformerPreTrainedModel(PreTrainedModel):
    """Common base for the Atomformer models.

    Declares the configuration class and the class-level attributes
    (`base_model_prefix`, gradient-checkpointing support, non-splittable
    modules) shared by every model in this file.
    """

    config_class = AtomformerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ParallelBlock"]

    def _set_gradient_checkpointing(
        self, module: nn.Module, value: bool = False
    ) -> None:
        """Set the ``gradient_checkpointing`` flag on encoder modules only."""
        if not isinstance(module, AtomformerEncoder):
            return
        module.gradient_checkpointing = value
|
|
|
|
|
class AtomformerModel(AtomformerPreTrainedModel):
    """Atomformer model for atom modeling.

    Thin wrapper that exposes the bare ``AtomformerEncoder`` without any
    task-specific head.
    """

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function call for the transformer model.

        Returns:
            The encoder output unchanged: a ``(token_features,
            pairwise_positional_embeddings)`` tuple. Both include the extra
            token that the encoder prepends at position 0.
        """
        # The encoder returns a 2-tuple, not a single tensor.
        output: Tuple[torch.Tensor, torch.Tensor] = self.encoder(
            input_ids, coords, attention_mask
        )
        return output
|
|
|
|
|
class AtomformerForMaskedAM(AtomformerPreTrainedModel):
    """Atomformer with an atom modeling head on top for masked atom modeling."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.am_head = nn.Linear(config.dim, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the masked atom modeling model.

        Args:
            input_ids: Atom token ids.
            coords: Atom coordinates.
            labels: Optional target token ids, one per input atom.
            fixed: Accepted for interface compatibility; not used here.
            attention_mask: Optional padding mask forwarded to the encoder.

        Returns:
            ``(loss, logits)``. ``loss`` is ``None`` without ``labels``. When
            ``labels`` is given, the returned logits are flattened to
            ``(-1, vocab_size)``, the shape used for the cross-entropy loss.
        """
        # The encoder returns (token_features, pairwise_embeddings); only the
        # token features feed the head.
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        # Position 0 is the extra token prepended by the encoder; drop it so
        # there is exactly one prediction per input atom.
        logits = self.am_head(hidden_states[:, 1:])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            logits = logits.reshape(-1, self.config.vocab_size)
            labels = labels.reshape(-1)
            loss = loss_fct(logits, labels)

        return loss, logits
|
|
|
|
|
class AtomformerForCoordinateAM(AtomformerPreTrainedModel):
    """Atomformer with an atom coordinate head on top for coordinate denoising."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the coordinate atom modeling model.

        Args:
            input_ids: Atom token ids.
            coords: Input atom coordinates.
            labels_coords: Optional target coordinates, one 3-vector per atom.
            fixed: Accepted for interface compatibility; not used here.
            attention_mask: Optional padding mask forwarded to the encoder.

        Returns:
            ``(loss, coords_pred)``; ``loss`` is the mean absolute error
            against ``labels_coords`` or ``None`` when no labels are given.
        """
        # The encoder returns (token_features, pairwise_embeddings).
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        # Skip position 0 (the extra token prepended by the encoder) so that
        # predictions line up one-to-one with the input atoms.
        coords_pred = self.coords_head(hidden_states[:, 1:])

        loss = None
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(coords_pred, labels_coords)

        return loss, coords_pred
|
|
|
|
|
class InitialStructure2RelaxedStructure(AtomformerPreTrainedModel):
    """Atomformer with a coordinate head on top for relaxed structure prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call.

        Initial structure to relaxed structure model.

        Args:
            input_ids: Atom token ids.
            coords: Initial atom coordinates.
            labels_coords: Optional target coordinates, one 3-vector per atom.
            fixed: Accepted for interface compatibility; not used here.
            attention_mask: Optional padding mask forwarded to the encoder.

        Returns:
            ``(loss, coords_pred)``; ``loss`` is the mean absolute error
            against ``labels_coords`` or ``None`` when no labels are given.
        """
        # The encoder returns (token_features, pairwise_embeddings).
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        # Skip position 0 (the extra token prepended by the encoder) so that
        # predictions line up one-to-one with the input atoms.
        coords_pred = self.coords_head(hidden_states[:, 1:])

        loss = None
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(coords_pred, labels_coords)

        return loss, coords_pred
|
|
|
|
|
class InitialStructure2RelaxedEnergy(AtomformerPreTrainedModel):
    """Atomformer with an energy head on top for relaxed energy prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_energy: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the relaxed energy prediction model.

        Args:
            input_ids: Atom token ids.
            coords: Atom coordinates.
            labels_energy: Optional target energies, one scalar per structure.
            fixed: Accepted for interface compatibility; not used here.
            attention_mask: Optional padding mask forwarded to the encoder.

        Returns:
            ``(loss, energy)``; ``energy`` has one value per batch element and
            ``loss`` is the mean absolute error or ``None`` without labels.
        """
        # The encoder returns (token_features, pairwise_embeddings); indexing
        # the tuple directly with [:, 0] would raise a TypeError.
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        # The energy is read from position 0, the token the encoder prepends.
        energy = self.energy_head(self.energy_norm(hidden_states[:, 0])).squeeze(-1)

        loss = None
        if labels_energy is not None:
            loss_fct = nn.L1Loss()
            loss = loss_fct(energy, labels_energy)

        return loss, energy
|
|
|
|
|
class InitialStructure2RelaxedStructureAndEnergy(AtomformerPreTrainedModel):
    """Atomformer with a coordinate and an energy head."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.energy_head = nn.Linear(config.dim, 1, bias=False)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward function call for the relaxed structure and energy model.

        Args:
            input_ids: Atom token ids.
            coords: Initial atom coordinates.
            labels_coords: Optional target coordinates, one 3-vector per atom.
            forces: Accepted for interface compatibility; not used here.
            total_energy: Accepted for interface compatibility; not used here.
            formation_energy: Optional target formation energy per structure.
            has_formation_energy: Optional index/mask selecting the batch
                entries that contribute to the energy loss (presumably a
                boolean mask - TODO confirm against the data collator).
            attention_mask: Optional padding mask forwarded to the encoder.

        Returns:
            ``(loss, (formation_energy_pred, coords_pred))`` where ``loss`` is
            a scalar tensor: the sum of the available L1 losses, or zero when
            no targets are supplied.
        """
        atom_hidden_states, pos_hidden_states = self.encoder(
            input_ids, coords, attention_mask
        )

        # The energy is read from position 0 (the token the encoder prepends).
        # The head registered in __init__ is `energy_head`.
        formation_energy_pred = self.energy_head(
            self.energy_norm(atom_hidden_states[:, 0])
        ).squeeze(-1)
        loss_formation_energy = None
        if formation_energy is not None:
            loss_fct = nn.L1Loss()
            if has_formation_energy is not None:
                energy_pred = formation_energy_pred[has_formation_energy]
                energy_target = formation_energy[has_formation_energy]
            else:
                energy_pred, energy_target = formation_energy_pred, formation_energy
            loss_formation_energy = loss_fct(energy_pred, energy_target)

        # Per-atom predictions skip the prepended token at position 0.
        coords_pred = self.coords_head(atom_hidden_states[:, 1:])
        loss_coords = None
        if labels_coords is not None:
            loss_fct = nn.L1Loss()
            loss_coords = loss_fct(coords_pred, labels_coords)

        # Start from a true scalar zero: torch.Tensor(0) would be an *empty*
        # tensor, and adding a scalar loss to it yields an empty result.
        loss = torch.zeros((), device=coords.device)
        if loss_formation_energy is not None:
            loss = loss + loss_formation_energy
        if loss_coords is not None:
            loss = loss + loss_coords

        return loss, (formation_energy_pred, coords_pred)
|
|
|
|
|
class Structure2Energy(AtomformerPreTrainedModel):
    """Atomformer with a formation energy head on top for energy prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward function call for the structure to energy model.

        Args:
            input_ids: Atom token ids.
            coords: Atom coordinates.
            forces: Accepted for interface compatibility; not used here.
            total_energy: Accepted for interface compatibility; not used here.
            formation_energy: Optional target formation energy per structure.
            has_formation_energy: Optional index/mask selecting the batch
                entries that contribute to the loss (presumably a boolean
                mask - TODO confirm against the data collator).
            attention_mask: Optional padding mask forwarded to the encoder.

        Returns:
            ``(loss, (formation_energy_pred, mask))`` where ``loss`` is a
            scalar tensor (zero when no target is given) and ``mask`` is the
            boolean attention mask or ``None``.
        """
        atom_hidden_states, pos_hidden_states = self.encoder(
            input_ids, coords, attention_mask
        )

        # The energy is read from position 0, the token the encoder prepends.
        formation_energy_pred: torch.Tensor = self.formation_energy_head(
            self.energy_norm(atom_hidden_states[:, 0])
        ).squeeze(-1)

        # Scalar zero fallback: torch.Tensor(0) would be an *empty* tensor.
        loss = torch.zeros((), device=coords.device)
        if formation_energy is not None:
            loss_fct = nn.L1Loss()
            if has_formation_energy is not None:
                energy_pred = formation_energy_pred[has_formation_energy]
                energy_target = formation_energy[has_formation_energy]
            else:
                energy_pred, energy_target = formation_energy_pred, formation_energy
            loss = loss_fct(energy_pred, energy_target)

        return loss, (
            formation_energy_pred,
            attention_mask.bool() if attention_mask is not None else None,
        )
|
|
|
|
|
class Structure2Forces(AtomformerPreTrainedModel):
    """Atomformer with a forces head on top for forces prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.force_norm = nn.LayerNorm(config.dim)
        self.force_head = nn.Linear(config.dim, 3)
        # NOTE: the energy modules below are registered but not used in
        # forward(); they are kept so the module's parameters stay unchanged.
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward function call for the structure to forces model.

        Args:
            input_ids: Atom token ids.
            coords: Atom coordinates.
            forces: Optional target forces, one 3-vector per atom.
            total_energy: Accepted for interface compatibility; not used here.
            formation_energy: Accepted for interface compatibility; not used.
            has_formation_energy: Accepted for interface compatibility; not
                used here.
            attention_mask: Optional padding mask; when given, only positions
                where it is true contribute to the loss.

        Returns:
            ``(loss, (forces_pred, mask))`` where ``loss`` is a scalar tensor
            (zero when no target is given) and ``mask`` is the boolean
            attention mask or ``None``.
        """
        atom_hidden_states, pos_hidden_states = self.encoder(
            input_ids, coords, attention_mask
        )
        attention_mask = attention_mask.bool() if attention_mask is not None else None

        # Per-atom predictions skip the token the encoder prepends at index 0.
        forces_pred: torch.Tensor = self.force_head(
            self.force_norm(atom_hidden_states[:, 1:])
        )

        # Scalar zero fallback: torch.Tensor(0) would be an *empty* tensor.
        loss = torch.zeros((), device=coords.device)
        if forces is not None:
            loss_fct = nn.L1Loss()
            if attention_mask is not None:
                loss = loss_fct(forces_pred[attention_mask], forces[attention_mask])
            else:
                loss = loss_fct(forces_pred, forces)

        return loss, (forces_pred, attention_mask)
|
|
|
|
|
class Structure2EnergyAndForces(AtomformerPreTrainedModel):
    """Atomformer with an energy and forces head for energy and forces prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.force_norm = nn.LayerNorm(config.dim)
        self.force_head = nn.Linear(config.dim, 3)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]:
        """Forward function call for the structure to energy and forces model.

        Args:
            input_ids: Atom token ids.
            coords: Atom coordinates.
            forces: Optional target forces, one 3-vector per atom.
            total_energy: Accepted for interface compatibility; not used here.
            formation_energy: Optional target formation energy per structure.
            has_formation_energy: Optional index/mask selecting the batch
                entries that contribute to the energy loss (presumably a
                boolean mask - TODO confirm against the data collator).
            attention_mask: Optional padding mask; when given, only positions
                where it is true contribute to the forces loss.

        Returns:
            ``(loss, (formation_energy_pred, forces_pred, mask))`` where
            ``loss`` is a scalar tensor: the sum of the available L1 losses,
            or zero when no targets are supplied.
        """
        atom_hidden_states, pos_hidden_states = self.encoder(
            input_ids, coords, attention_mask
        )

        # The energy is read from position 0, the token the encoder prepends.
        formation_energy_pred: torch.Tensor = self.formation_energy_head(
            self.energy_norm(atom_hidden_states[:, 0])
        ).squeeze(-1)
        loss_formation_energy = None
        if formation_energy is not None:
            loss_fct = nn.L1Loss()
            if has_formation_energy is not None:
                energy_pred = formation_energy_pred[has_formation_energy]
                energy_target = formation_energy[has_formation_energy]
            else:
                energy_pred, energy_target = formation_energy_pred, formation_energy
            loss_formation_energy = loss_fct(energy_pred, energy_target)

        attention_mask = attention_mask.bool() if attention_mask is not None else None
        # Per-atom predictions skip the prepended token at position 0.
        forces_pred: torch.Tensor = self.force_head(
            self.force_norm(atom_hidden_states[:, 1:])
        )
        loss_forces = None
        if forces is not None:
            loss_fct = nn.L1Loss()
            if attention_mask is not None:
                loss_forces = loss_fct(
                    forces_pred[attention_mask], forces[attention_mask]
                )
            else:
                loss_forces = loss_fct(forces_pred, forces)

        # Start from a true scalar zero: torch.Tensor(0) would be an *empty*
        # tensor, and adding a scalar loss to it yields an empty result.
        loss = torch.zeros((), device=coords.device)
        if loss_formation_energy is not None:
            loss = loss + loss_formation_energy
        if loss_forces is not None:
            loss = loss + loss_forces

        return loss, (formation_energy_pred, forces_pred, attention_mask)
|
|