# atomformer-base / modeling_atomformer.py
# Last commit: e28f6a5 "Update modeling_atomformer.py" (akore, verified)
"""Implementation of the Atomformer model."""
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as f
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from .configuration_atomformer import AtomformerConfig
# Fixed per-element feature table: a list of 118 rows, each a list of 17 floats.
# Observations from the literal below (not documented elsewhere in this file):
# - Column 0 of row i is (approximately) i / 117; columns 3 and 4 appear to
#   repeat column 0.
# - Columns 5 and 6 look like normalized period / group values, and they fit
#   the sequence H, He, Li, ... so row i appears to describe the element with
#   atomic number Z = i + 1. NOTE(review): confirm against the code that
#   indexes this table.
# - Most values lie in [0, 1] (min-max normalized, presumably); -1.0 appears
#   to be used as a "missing / not applicable" sentinel. TODO: confirm with
#   the script that generated this table.
# - The physical meaning of each individual column is not recorded here.
ATOM_METADATA = [
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.106761565836299,
0.4573170731707318,
0.46896368424867707,
0.0,
0.0,
0.0027383806383189145,
0.0,
1.0,
0.0,
0.0,
],
[
0.008547008547008548,
0.010187317385107808,
0.011235955056179775,
0.008547008547008548,
0.008547008547008548,
0.0,
1.0,
0.0,
-1.0,
0.9999999999999999,
2.1731754967921256e-06,
-1.0,
0.0,
0.010000000000000002,
0.3588318085855031,
0.0,
-1.0,
],
[
0.017094017094017096,
0.02018415404448405,
0.02247191011235955,
0.017094017094017096,
0.017094017094017096,
0.16666666666666666,
0.0,
0.5729537366548044,
0.08536585365853658,
0.0723802160098582,
0.01302222611458848,
0.1117635470484688,
0.2746530986669577,
0.010000000000000002,
0.2454609429978888,
0.16666666666666666,
0.0,
],
[
0.025641025641025644,
0.027228539455021038,
0.028089887640449437,
0.025641025641025644,
0.025641025641025644,
0.16666666666666666,
0.058823529411764705,
0.32384341637010683,
0.2652439024390244,
0.2623432478797689,
0.0451198574701265,
0.39298038243761085,
0.4668171696125004,
0.015,
0.12181562280084446,
0.16666666666666666,
0.14285714285714285,
],
[
0.03418803418803419,
0.03334773276914757,
0.033707865168539325,
0.03418803418803419,
0.03418803418803419,
0.16666666666666666,
0.7058823529411764,
0.25266903914590755,
0.4085365853658537,
0.2128252833015198,
0.057071103187614054,
0.6504807478441018,
0.715419845245687,
0.015,
0.06558761435608726,
0.16666666666666666,
0.2857142857142857,
],
[
0.042735042735042736,
0.03742946260625253,
0.033707865168539325,
0.042735042735042736,
0.042735042735042736,
0.16666666666666666,
0.7647058823529411,
0.14946619217081855,
0.5640243902439024,
0.3559765143644139,
0.055363782370830124,
1.0000000000000002,
0.7324707832177849,
0.020000000000000004,
0.04327938071780436,
0.16666666666666666,
0.42857142857142855,
],
[
0.051282051282051294,
0.04421873990197045,
0.03932584269662921,
0.051282051282051294,
0.051282051282051294,
0.16666666666666666,
0.8235294117647058,
0.09252669039145908,
0.7134146341463414,
0.514180781404789,
2.8295183993586364e-05,
0.012484827687008686,
0.012471056032792366,
0.025,
0.06657283603096412,
0.16666666666666666,
0.5714285714285714,
],
[
0.05982905982905984,
0.050994411431564704,
0.0449438202247191,
0.05982905982905984,
0.05982905982905984,
0.16666666666666666,
0.8823529411764706,
0.05693950177935947,
0.8353658536585367,
0.4699156740039142,
3.268543752245935e-05,
0.00923366315240946,
0.014660396468409729,
0.025,
0.057987332864180154,
0.16666666666666666,
0.7142857142857142,
],
[
0.06837606837606838,
0.06119533458279619,
0.056179775280898875,
0.06837606837606838,
0.06837606837606838,
0.16666666666666666,
0.9411764705882353,
0.028469750889679707,
1.0,
0.6537753400826345,
3.9270817815768815e-05,
0.01002929606822616,
0.01377886297525227,
0.015,
0.051372273047149884,
0.16666666666666666,
0.8571428571428572,
],
[
0.07692307692307693,
0.06521583847234458,
0.056179775280898875,
0.07692307692307693,
0.07692307692307693,
0.16666666666666666,
1.0,
0.007117437722419961,
-1.0,
0.8539203131418077,
1.9758579909666677e-05,
0.002676173590325307,
0.003896139326624358,
0.025,
0.06586910626319493,
0.16666666666666666,
1.0,
],
[
0.08547008547008549,
0.07477388917423204,
0.06741573033707865,
0.08547008547008549,
0.08547008547008549,
0.33333333333333337,
0.0,
0.6085409252669042,
0.07012195121951223,
0.06017348442747725,
0.023680786070796774,
0.09074155275516495,
0.19638929337502856,
0.020000000000000004,
0.07980295566502463,
0.33333333333333337,
0.0,
],
[
0.09401709401709403,
0.0792467847873929,
0.06741573033707865,
0.09401709401709403,
0.09401709401709403,
0.33333333333333337,
0.058823529411764705,
0.43060498220640586,
0.1859756097560976,
0.18132746997849566,
0.04243692475803745,
0.23105764525702377,
0.2316847349772711,
0.025,
0.0653764954257565,
0.33333333333333337,
0.14285714285714285,
],
[
0.10256410256410257,
0.08835244376566789,
0.07865168539325842,
0.10256410256410257,
0.10256410256410257,
0.33333333333333337,
0.7058823529411764,
0.4661921708185055,
0.2774390243902439,
0.10108971416145165,
0.06585161024536003,
0.23366315240945865,
0.4753426385985493,
0.025,
0.05650950035186488,
0.33333333333333337,
0.2857142857142857,
],
[
0.11111111111111113,
0.09210763521580445,
0.07865168539325842,
0.11111111111111113,
0.11111111111111113,
0.33333333333333337,
0.7647058823529411,
0.35943060498220647,
0.36585365853658536,
0.20575543044917485,
0.05682720021378778,
0.4242464682668294,
0.6025426358703994,
0.025,
0.04299788881069668,
0.33333333333333337,
0.42857142857142855,
],
[
0.11965811965811968,
0.10193099835710374,
0.0898876404494382,
0.11965811965811968,
0.11965811965811968,
0.33333333333333337,
0.8235294117647058,
0.25266903914590755,
0.4542682926829268,
0.3185927948389591,
0.04438814854864767,
0.07704039807065373,
0.09357213740327856,
0.020000000000000004,
0.04750175932441942,
0.33333333333333337,
0.5714285714285714,
],
[
0.12820512820512822,
0.10564197106733833,
0.0898876404494382,
0.12820512820512822,
0.12820512820512822,
0.33333333333333337,
0.8823529411764706,
0.21708185053380794,
0.5731707317073171,
0.31247009930654546,
0.05048572289430458,
0.09515439218602051,
0.1216720831812958,
0.035,
0.04334975369458127,
0.33333333333333337,
0.7142857142857142,
],
[
0.13675213675213677,
0.11716605497409803,
0.10112359550561797,
0.13675213675213677,
0.13675213675213677,
0.33333333333333337,
0.9411764705882353,
0.17081850533807832,
0.75,
0.43848068233986515,
7.610016686353661e-05,
0.04019725595612581,
0.040050948202660634,
0.04,
0.0270935960591133,
0.33333333333333337,
0.8571428571428572,
],
[
0.1452991452991453,
0.13245553465558702,
0.12359550561797752,
0.1452991452991453,
0.1452991452991453,
0.33333333333333337,
1.0,
0.13879003558718864,
-1.0,
0.573402276077029,
4.1222041606379033e-05,
0.0177390552812359,
0.01416591926721889,
0.025,
0.029978888106966927,
0.33333333333333337,
1.0,
],
[
0.15384615384615385,
0.12956430935430432,
0.11235955056179775,
0.15384615384615385,
0.15384615384615385,
0.5,
0.0,
0.8220640569395019,
0.036585365853658514,
0.021591320946190845,
0.02102224365609036,
0.08193366760083631,
0.17524613028962724,
0.035,
0.04665728360309641,
0.5,
0.0,
],
[
0.1623931623931624,
0.13289772205460673,
0.11235955056179775,
0.1623931623931624,
0.1623931623931624,
0.5,
0.058823529411764705,
0.6085409252669042,
0.09146341463414634,
0.10724623674100561,
0.03755886528151192,
0.27910065518972543,
0.2988654305873366,
0.05500000000000001,
0.038916256157635463,
0.5,
0.14285714285714285,
],
[
0.17094017094017097,
0.14948995384243843,
0.1348314606741573,
0.17094017094017097,
0.17094017094017097,
0.5,
0.1176470588235294,
0.5729537366548044,
0.201219512195122,
0.128910044216783,
0.07292479648632205,
0.4570377290145464,
0.5293941119700996,
0.06,
0.03335679099225897,
0.5,
-1.0,
],
[
0.17948717948717952,
0.15939155013894887,
0.14606741573033707,
0.17948717948717952,
0.17948717948717952,
0.5,
0.1764705882352941,
0.5373665480427048,
0.25609756097560976,
0.14179331674197213,
0.11072975742939495,
0.48779542320426544,
0.6062938422242609,
0.03,
0.03019000703729768,
0.5,
-1.0,
],
[
0.18803418803418806,
0.16985098284653036,
0.15730337078651685,
0.18803418803418806,
0.18803418803418806,
0.5,
0.23529411764705882,
0.5017793594306051,
0.2835365853658536,
0.13783555222654456,
0.14902252432012042,
0.5493108115837035,
0.6267549677907782,
0.03,
0.027797325826882473,
0.5,
-1.0,
],
[
0.1965811965811966,
0.17343610222012087,
0.15730337078651685,
0.1965811965811966,
0.1965811965811966,
0.5,
0.2941176470588235,
0.5017793594306051,
0.29268292682926833,
0.13881653659361634,
0.1743884335980532,
0.5378719996949651,
0.5012600643161381,
0.03,
0.024982406755805767,
0.5,
-1.0,
],
[
0.20512820512820515,
0.1834431432040899,
0.16853932584269662,
0.20512820512820515,
0.20512820512820515,
0.5,
0.3529411764705882,
0.4661921708185055,
0.25914634146341464,
0.17107304225964678,
0.1814616198390152,
0.38255835382787134,
0.3972493426863412,
0.04,
0.0270935960591133,
0.5,
-1.0,
],
[
0.2136752136752137,
0.18652825067263504,
0.16853932584269662,
0.2136752136752137,
0.2136752136752137,
0.5,
0.4117647058823529,
0.43060498220640586,
0.3445121951219513,
0.19370816923188441,
0.1919494477135451,
0.4560209457355474,
0.5336568464631241,
0.035,
0.024982406755805767,
0.5,
-1.0,
],
[
0.22222222222222224,
0.19703190212011848,
0.1797752808988764,
0.22222222222222224,
0.22222222222222224,
0.5,
0.47058823529411764,
0.43060498220640586,
0.3597560975609756,
0.19267402807644915,
0.2160958421223465,
0.4458531129455577,
0.5449104655247086,
0.05500000000000001,
0.023011963406052074,
0.5,
-1.0,
],
[
0.2307692307692308,
0.19621555615269748,
0.1741573033707865,
0.2307692307692308,
0.2307692307692308,
0.5,
0.5294117647058824,
0.3950177935943062,
0.36890243902439024,
0.18101819411892625,
0.2173153569914779,
0.4351768885160684,
0.5425233342086149,
0.04,
0.024630541871921183,
0.5,
-1.0,
],
[
0.23931623931623935,
0.21272275190225615,
0.19662921348314605,
0.23931623931623935,
0.23931623931623935,
0.5,
0.5882352941176471,
0.3950177935943062,
0.36585365853658536,
0.1852030830937251,
0.2185348718606093,
0.3415311485202626,
0.4826745419265514,
0.04,
0.02047853624208304,
0.5,
-1.0,
],
[
0.2478632478632479,
0.2189609956699649,
0.19662921348314605,
0.2478632478632479,
0.2478632478632479,
0.5,
0.6470588235294117,
0.35943060498220647,
0.28963414634146345,
0.2657984391233962,
0.17390062765040062,
0.17252397384325016,
0.20048151848833204,
0.06,
0.020689655172413793,
0.5,
-1.0,
],
[
0.2564102564102564,
0.23373345623875397,
0.2191011235955056,
0.2564102564102564,
0.2564102564102564,
0.5,
0.7058823529411764,
0.4661921708185055,
0.33841463414634154,
0.10174209292773093,
0.14414446484359486,
0.0733952300154424,
0.4216321839864411,
0.05500000000000001,
0.019493314567206193,
0.5,
0.2857142857142857,
],
[
0.26495726495726496,
0.24365546118444995,
0.2303370786516854,
0.26495726495726496,
0.26495726495726496,
0.5,
0.7647058823529411,
0.35943060498220647,
0.39939024390243894,
0.19356319617271125,
0.12975418938784455,
0.30434230009087504,
0.5288825838309366,
0.07,
0.015904292751583393,
0.5,
0.42857142857142855,
],
[
0.27350427350427353,
0.2514175507580112,
0.23595505617977527,
0.27350427350427353,
0.27350427350427353,
0.5,
0.8235294117647058,
0.2882562277580072,
0.451219512195122,
0.28485756396936235,
0.1409737261838533,
0.2735083471552311,
0.15052227023008535,
0.05500000000000001,
0.016537649542575653,
0.5,
0.5714285714285714,
],
[
0.28205128205128205,
0.2651525716598694,
0.25280898876404495,
0.28205128205128205,
0.28205128205128205,
0.5,
0.8823529411764706,
0.25266903914590755,
0.5640243902439024,
0.2831082223886727,
0.11731513772270441,
0.12200763858438347,
0.1626284361902748,
0.085,
0.01597466572836031,
0.5,
0.7142857142857142,
],
[
0.2905982905982906,
0.2683635324650587,
0.25280898876404495,
0.2905982905982906,
0.2905982905982906,
0.5,
0.9411764705882353,
0.21708185053380794,
0.6890243902439025,
0.38272404378186387,
0.07609553514606364,
0.06402557209946683,
0.055889564484942325,
0.08,
0.026741731175228708,
0.5,
0.8571428571428572,
],
[
0.29914529914529914,
0.2816087457864643,
0.2696629213483146,
0.29914529914529914,
0.29914529914529914,
0.5,
1.0,
0.18149466192170824,
-1.0,
0.48835141469543564,
8.878312150250298e-05,
0.025865695638635226,
0.019729640327514418,
0.1,
0.010837438423645322,
0.5,
1.0,
],
[
0.3076923076923077,
0.28728915314310205,
0.2696629213483146,
0.3076923076923077,
0.3076923076923077,
0.6666666666666666,
0.0,
0.8932384341637012,
0.036585365853658514,
0.013685456785947292,
0.03731496230768564,
0.07590668471456988,
0.16313996432943775,
0.085,
0.01893033075299085,
0.6666666666666666,
0.0,
],
[
0.3162393162393162,
0.2946090553176436,
0.2808988764044944,
0.3162393162393162,
0.3162393162393162,
0.6666666666666666,
0.058823529411764705,
0.7153024911032031,
0.07621951219512194,
0.08703215985695992,
0.06438819240240236,
0.26130694780724334,
0.2814734738557968,
0.075,
0.014567206192821956,
0.6666666666666666,
0.14285714285714285,
],
[
0.3247863247863248,
0.29898330912640775,
0.2808988764044944,
0.3247863247863248,
0.3247863247863248,
0.6666666666666666,
0.1176470588235294,
0.6441281138790037,
0.15853658536585366,
0.11227680189431463,
0.109022436612611,
0.4537331833577997,
0.6146488018305888,
0.09,
0.014356087262491202,
0.6666666666666666,
-1.0,
],
[
0.3333333333333333,
0.3068678505950822,
0.28651685393258425,
0.3333333333333333,
0.3333333333333333,
0.6666666666666666,
0.1764705882352941,
0.6085409252669042,
0.19207317073170735,
0.1324087273781622,
0.15877864327317145,
0.5366010205962164,
0.7976053662711987,
0.085,
0.012948627726952853,
0.6666666666666666,
-1.0,
],
[
0.3418803418803419,
0.312589075250091,
0.29213483146067415,
0.3418803418803419,
0.3418803418803419,
0.6666666666666666,
0.23529411764705882,
0.5729537366548044,
0.27439024390243905,
0.13844927151037764,
0.2090226558813845,
0.6931856455620589,
0.8547260084777264,
0.105,
0.012033779028852921,
0.6666666666666666,
-1.0,
],
[
0.35042735042735046,
0.3229770776855231,
0.3033707865168539,
0.35042735042735046,
0.35042735042735046,
0.6666666666666666,
0.2941176470588235,
0.5373665480427048,
0.44512195121951226,
0.15456544325512842,
0.24877884061506758,
0.7310608227047707,
0.8368225236070237,
0.085,
0.011048557353976075,
0.6666666666666666,
-1.0,
],
[
0.358974358974359,
0.3299160184086015,
0.3089887640449438,
0.358974358974359,
0.358974358974359,
0.6666666666666666,
0.3529411764705882,
0.5373665480427048,
0.36585365853658536,
0.16363109188875738,
0.28048622721248356,
0.6250611658691273,
0.8774037559806166,
0.1,
-1.0,
0.6666666666666666,
-1.0,
],
[
0.36752136752136755,
0.3403584439085284,
0.3202247191011236,
0.36752136752136755,
0.36752136752136755,
0.6666666666666666,
0.4117647058823529,
0.5017793594306051,
0.4573170731707318,
0.16752120230990405,
0.30243749485684845,
0.6377709568566146,
0.7534434369234653,
0.065,
0.010133708655876143,
0.6666666666666666,
-1.0,
],
[
0.37606837606837606,
0.346603490559299,
0.3258426966292135,
0.37606837606837606,
0.37606837606837606,
0.6666666666666666,
0.47058823529411764,
0.4661921708185055,
0.4817073170731707,
0.17227631865078405,
0.30243749485684845,
0.5655793440476872,
0.67586166915042,
0.085,
0.01048557353976073,
0.6666666666666666,
-1.0,
],
[
0.38461538461538464,
0.35855615609895475,
0.33707865168539325,
0.38461538461538464,
0.38461538461538464,
0.6666666666666666,
0.5294117647058824,
0.4661921708185055,
0.4573170731707318,
0.21470510063546525,
0.2926813759037974,
0.4603422746712931,
0.5510488031946638,
0.09,
0.01055594651653765,
0.6666666666666666,
-1.0,
],
[
0.39316239316239315,
0.363481443435728,
0.34269662921348315,
0.39316239316239315,
0.39316239316239315,
0.6666666666666666,
0.5882352941176471,
0.4661921708185055,
0.375,
0.17794476526445505,
0.25609592982985585,
0.31011254519919423,
0.41447079003816,
0.12000000000000001,
0.00992258972554539,
0.6666666666666666,
-1.0,
],
[
0.4017094017094017,
0.37893419231070125,
0.3595505617977528,
0.4017094017094017,
0.4017094017094017,
0.6666666666666666,
0.6470588235294117,
0.43060498220640586,
0.301829268292683,
0.24644936815908378,
0.21194949156729978,
0.1474729758069129,
0.17661020532739505,
0.095,
0.00971147079521464,
0.6666666666666666,
-1.0,
],
[
0.4102564102564103,
0.38712146207562764,
0.3707865168539326,
0.4102564102564103,
0.4102564102564103,
0.6666666666666666,
0.7058823529411764,
0.5373665480427048,
0.3292682926829269,
0.09145383816174166,
0.1782908811792736,
0.10567809912365993,
0.39912494586327196,
0.15500000000000003,
0.009781843771991556,
0.6666666666666666,
0.2857142857142857,
],
[
0.4188034188034188,
0.40035987251397137,
0.38764044943820225,
0.4188034188034188,
0.4188034188034188,
0.6666666666666666,
0.7647058823529411,
0.43060498220640586,
0.38414634146341464,
0.16671901804914588,
0.17780307523162106,
0.12481904435081566,
0.4894949171153905,
0.125,
0.009429978888106968,
0.6666666666666666,
0.42857142857142855,
],
[
0.4273504273504274,
0.4107342691832799,
0.398876404494382,
0.4273504273504274,
0.4273504273504274,
0.6666666666666666,
0.8235294117647058,
0.35943060498220647,
0.41158536585365846,
0.22782516249063714,
0.16316889680204447,
0.22620250509980366,
0.3164278966985974,
0.13,
0.007952146375791697,
0.6666666666666666,
0.5714285714285714,
],
[
0.4358974358974359,
0.43059868772385734,
0.42696629213483145,
0.4358974358974359,
0.4358974358974359,
0.6666666666666666,
0.8823529411764706,
0.32384341637010683,
0.426829268292683,
0.24721289293739582,
0.15194936000603573,
0.1801295127701625,
0.21429277824573129,
0.13,
0.007600281491907108,
0.6666666666666666,
0.7142857142857142,
],
[
0.4444444444444445,
0.42823128441833647,
0.4157303370786517,
0.4444444444444445,
0.4444444444444445,
0.6666666666666666,
0.9411764705882353,
0.2882562277580072,
0.5975609756097562,
0.31688211274071565,
0.12024197340861972,
0.09468158796128598,
0.07727144070195302,
0.105,
0.008444757213230118,
0.6666666666666666,
0.8571428571428572,
],
[
0.452991452991453,
0.4431602112975479,
0.43258426966292135,
0.452991452991453,
0.452991452991453,
0.6666666666666666,
1.0,
0.25266903914590755,
-1.0,
0.39799453934810447,
0.0001414661638489788,
0.03743668935364358,
0.027419613352930545,
0.14,
0.00450387051372273,
0.6666666666666666,
1.0,
],
[
0.46153846153846156,
0.4486433350453922,
0.4382022471910112,
0.46153846153846156,
0.46153846153846156,
0.8333333333333334,
0.0,
1.0000000000000002,
0.0274390243902439,
0.0,
0.045607663417779054,
0.0730876530735452,
0.16024130487418112,
0.095,
0.010415200562983815,
0.8333333333333334,
0.0,
],
[
0.47008547008547014,
0.463684509495124,
0.4550561797752809,
0.47008547008547014,
0.47008547008547014,
0.8333333333333334,
0.058823529411764705,
0.8220640569395019,
0.05792682926829271,
0.06368183245946799,
0.08755897491589865,
0.25113911501725356,
0.36928580441210074,
0.11,
0.007741027445460941,
0.8333333333333334,
0.14285714285714285,
],
[
0.47863247863247865,
0.46905198423091704,
0.4606741573033708,
0.47863247863247865,
0.47863247863247865,
0.8333333333333334,
0.1176470588235294,
0.7864768683274024,
0.12195121951219515,
0.08132988619614859,
0.14999813621542551,
0.2996905165894547,
0.6364740024348741,
0.08,
0.007107670654468685,
0.8333333333333334,
-1.0,
],
[
0.4871794871794872,
0.47317112992486215,
0.4606741573033708,
0.4871794871794872,
0.4871794871794872,
0.8333333333333334,
-1.0,
0.7864768683274024,
0.1280487804878049,
0.07948389590934354,
0.16512012059265466,
0.2686786265799859,
0.6328933054607335,
0.08,
0.006896551724137932,
0.8333333333333334,
-1.0,
],
[
0.4957264957264958,
0.47586507161735137,
0.4606741573033708,
0.4957264957264958,
0.4957264957264958,
0.8333333333333334,
-1.0,
0.7864768683274024,
0.13109756097560973,
0.07630898591345106,
0.16512012059265466,
0.3024866706067019,
0.6460225276992488,
0.06,
0.006966924700914849,
0.8333333333333334,
-1.0,
],
[
0.5042735042735044,
0.48720547768144135,
0.47191011235955055,
0.5042735042735044,
0.5042735042735044,
0.8333333333333334,
-1.0,
0.7508896797153027,
0.13414634146341461,
0.07882185227245272,
0.1709737919644853,
0.3240933152854302,
0.5699753443436925,
0.065,
0.006755805770584096,
0.8333333333333334,
-1.0,
],
[
0.5128205128205129,
0.4897837703618793,
0.47191011235955055,
0.5128205128205129,
0.5128205128205129,
0.8333333333333334,
-1.0,
0.7508896797153027,
0.13109756097560973,
0.08157634039674291,
0.17707136631014223,
0.3024866706067019,
0.55735765024434,
0.05500000000000001,
-1.0,
0.8333333333333334,
-1.0,
],
[
0.5213675213675214,
0.5080154969676149,
0.4943820224719101,
0.5213675213675214,
0.5213675213675214,
0.8333333333333334,
-1.0,
0.7508896797153027,
0.14329268292682926,
0.0845579529804045,
0.18341284362962543,
0.33832828119141584,
0.35172333830083996,
0.07,
0.007248416608022519,
0.8333333333333334,
-1.0,
],
[
0.52991452991453,
0.5134714091832119,
0.5,
0.52991452991453,
0.52991452991453,
0.8333333333333334,
-1.0,
0.7508896797153027,
0.1524390243902439,
0.08584821320704569,
0.12780296559723434,
0.27477932625397977,
0.30653835267478063,
0.09,
0.0061928219563687536,
0.8333333333333334,
-1.0,
],
[
0.5384615384615385,
0.5314514291156592,
0.5224719101123595,
0.5384615384615385,
0.5384615384615385,
0.8333333333333334,
-1.0,
0.7153024911032031,
0.1524390243902439,
0.10902940536883562,
0.19268115663502394,
0.3993352779313545,
0.6039067109081672,
0.07,
0.009992962702322309,
0.8333333333333334,
-1.0,
],
[
0.5470085470085471,
0.5371488436799516,
0.5280898876404494,
0.5470085470085471,
0.5470085470085471,
0.8333333333333334,
-1.0,
0.7153024911032031,
0.1524390243902439,
0.09519414308840943,
0.20072995477129107,
0.41077408982009295,
0.596574807580165,
0.105,
0.0061928219563687536,
0.8333333333333334,
-1.0,
],
[
0.5555555555555557,
0.5493089971529934,
0.5449438202247191,
0.5555555555555557,
0.5555555555555557,
0.8333333333333334,
-1.0,
0.7153024911032031,
0.15853658536585366,
0.09882330200304443,
0.20853484993373195,
0.42348388080758015,
0.48352708882515627,
0.09,
0.005348346235045743,
0.8333333333333334,
-1.0,
],
[
0.5641025641025642,
0.557574500073131,
0.550561797752809,
0.5641025641025642,
0.5641025641025642,
0.8333333333333334,
-1.0,
0.7153024911032031,
0.16158536585365854,
0.10281489356561238,
0.21463242427938886,
0.43949821745181405,
0.509615023922466,
0.13,
0.004996481351161154,
0.8333333333333334,
-1.0,
],
[
0.5726495726495727,
0.5654964573986455,
0.5561797752808989,
0.5726495726495727,
0.5726495726495727,
0.8333333333333334,
-1.0,
0.7153024911032031,
0.16463414634146342,
0.10698045279918819,
0.22121780457269832,
0.45271640007880076,
0.596574807580165,
0.065,
0.005207600281491907,
0.8333333333333334,
-1.0,
],
[
0.5811965811965812,
0.5711938719629379,
0.5617977528089888,
0.5811965811965812,
0.5811965811965812,
0.8333333333333334,
-1.0,
0.6797153024911033,
0.1676829268292683,
0.11068209824340977,
0.22731537891835524,
0.4585629039330449,
0.37832280153731257,
0.075,
0.004644616467276566,
0.8333333333333334,
-1.0,
],
[
0.5897435897435899,
0.5852078110703316,
0.5786516853932584,
0.5897435897435899,
0.5897435897435899,
0.8333333333333334,
-1.0,
0.6797153024911033,
0.12195121951219515,
0.11405997052214464,
0.1699981800691802,
0.2752877178934793,
0.24975872922769482,
0.065,
0.004292751583391977,
0.8333333333333334,
-1.0,
],
[
0.5982905982905984,
0.5917147687189832,
0.5842696629213483,
0.5982905982905984,
0.5982905982905984,
0.8333333333333334,
-1.0,
0.6441281138790037,
0.17378048780487806,
0.07403290888443234,
0.2399983335573216,
0.4885580106635147,
0.6259024208921734,
0.095,
0.00422237860661506,
0.8333333333333334,
-1.0,
],
[
0.6068376068376069,
0.6036980472324172,
0.5955056179775281,
0.6068376068376069,
0.6068376068376069,
0.8333333333333334,
0.1764705882352941,
0.6085409252669042,
0.1829268292682927,
0.14164834368279897,
0.32438876250121335,
0.6319244530023704,
0.8306841859370685,
0.07,
0.0035186488388458817,
0.8333333333333334,
-1.0,
],
[
0.6153846153846155,
0.6120587905154204,
0.6067415730337078,
0.6153846153846155,
0.6153846153846155,
0.8333333333333334,
0.23529411764705882,
0.5729537366548044,
0.2439024390243902,
0.1766593374731196,
0.40731577360214744,
0.8274010383899237,
0.9764697055985051,
0.08,
0.003237156931738213,
0.8333333333333334,
-1.0,
],
[
0.623931623931624,
0.6218957594228435,
0.6179775280898876,
0.623931623931624,
0.623931623931624,
0.8333333333333334,
0.2941176470588235,
0.5373665480427048,
0.5060975609756098,
0.19185251407446782,
0.4707305467969794,
0.9318755203070687,
0.99300911543144,
0.095,
0.002674173117522871,
0.8333333333333334,
-1.0,
],
[
0.6324786324786326,
0.6299469715265329,
0.6235955056179775,
0.6324786324786326,
0.6324786324786326,
0.8333333333333334,
0.3529411764705882,
0.5373665480427048,
0.36585365853658536,
0.19037862130620728,
0.5121940523474464,
0.8741730692238767,
1.0,
0.09,
0.00302603800140746,
0.8333333333333334,
-1.0,
],
[
0.6410256410256411,
0.6436309708054273,
0.6404494382022472,
0.6410256410256411,
0.6410256410256411,
0.8333333333333334,
0.4117647058823529,
0.5017793594306051,
0.4573170731707318,
0.21960035760021265,
0.5512185281596508,
0.8352811088021659,
0.9004225222429487,
0.08,
0.0025334271639690367,
0.8333333333333334,
-1.0,
],
[
0.6495726495726497,
0.650389635127367,
0.6460674157303371,
0.6495726495726497,
0.6495726495726497,
0.8333333333333334,
0.47058823529411764,
0.5017793594306051,
0.4573170731707318,
0.24515427549713678,
0.5512185281596508,
0.6868307500683152,
0.8008450444858972,
0.11,
0.002603800140745954,
0.8333333333333334,
-1.0,
],
[
0.6581196581196582,
0.6601415679965169,
0.6573033707865168,
0.6581196581196582,
0.6581196581196582,
0.8333333333333334,
0.5294117647058824,
0.4661921708185055,
0.4817073170731707,
0.2447531833667577,
0.5243892010387603,
0.5162653550162368,
0.6980278885141472,
0.14500000000000002,
0.00274454609429979,
0.8333333333333334,
-1.0,
],
[
0.6666666666666667,
0.6665464823992409,
0.6629213483146067,
0.6666666666666667,
0.6666666666666667,
0.8333333333333334,
0.5882352941176471,
0.4661921708185055,
0.5609756097560976,
0.25764612076255833,
0.4707305467969794,
0.33644214820887275,
0.5328042995645191,
0.09,
0.002463054187192118,
0.8333333333333334,
-1.0,
],
[
0.6752136752136753,
0.6788699050657669,
0.6797752808988764,
0.6752136752136753,
0.6752136752136753,
0.8333333333333334,
0.6470588235294117,
0.4661921708185055,
0.39634146341463417,
0.31621523666851903,
0.3292668219777389,
0.05598790027897992,
0.1067013596417939,
0.115,
0.003237156931738213,
0.8333333333333334,
-1.0,
],
[
0.6837606837606839,
0.6917715727925495,
0.6910112359550562,
0.6837606837606839,
0.6837606837606839,
0.8333333333333334,
0.7058823529411764,
0.5729537366548044,
0.4085365853658537,
0.10700461497571703,
0.2902423461655346,
0.14310589162361226,
0.29698982741040586,
0.125,
0.002463054187192118,
0.8333333333333334,
0.2857142857142857,
],
[
0.6923076923076924,
0.7013534335851533,
0.7022471910112359,
0.6923076923076924,
0.6923076923076924,
0.8333333333333334,
0.7647058823529411,
0.4661921708185055,
0.4969512195121951,
0.1702370309517481,
0.27560816773595803,
0.14910491296970624,
0.3440504162133959,
0.13,
0.002463054187192118,
0.8333333333333334,
0.42857142857142855,
],
[
0.7008547008547009,
0.7074079995101924,
0.7078651685393258,
0.7008547008547009,
0.7008547008547009,
0.8333333333333334,
0.8235294117647058,
0.3950177935943062,
0.4024390243902439,
0.1639017082658806,
0.2392666246358428,
0.13484961139814058,
0.31250618096501487,
0.08,
0.0019704433497536944,
0.8333333333333334,
0.5714285714285714,
],
[
0.7094017094017095,
0.7108774698717316,
0.7078651685393258,
0.7094017094017095,
0.7094017094017095,
0.8333333333333334,
0.8823529411764706,
0.35943060498220647,
0.39634146341463417,
0.21857588131538888,
0.22731537891835524,
0.13039610063612506,
0.20985953437298585,
0.15500000000000003,
-1.0,
0.8333333333333334,
0.7142857142857142,
],
[
0.7179487179487181,
0.7108774698717316,
0.7022471910112359,
0.7179487179487181,
0.7179487179487181,
0.8333333333333334,
0.9411764705882353,
0.32384341637010683,
0.4573170731707318,
0.26124628506535874,
0.17072988899065902,
0.14259749998411278,
0.10329117204737433,
0.09,
-1.0,
0.8333333333333334,
0.8571428571428572,
],
[
0.7264957264957266,
0.7516947682427814,
0.7640449438202247,
0.7264957264957266,
0.7264957264957266,
0.8333333333333334,
1.0,
0.2882562277580072,
-1.0,
0.3312441104694711,
0.00023512490579826907,
0.047782459217458176,
0.03530908235262022,
0.085,
0.0,
0.8333333333333334,
1.0,
],
[
0.7350427350427351,
0.7550962097737021,
0.7640449438202247,
0.7350427350427351,
0.7350427350427351,
0.9999999999999999,
0.0,
-1.0,
0.0,
0.00864039432672098,
0.045607663417779054,
0.0726936495529331,
0.161264361152507,
0.09,
-1.0,
0.9999999999999999,
0.0,
],
[
0.7435897435897437,
0.7653005343664645,
0.7752808988764045,
0.7435897435897437,
0.7435897435897437,
0.9999999999999999,
0.058823529411764705,
-1.0,
0.06097560975609759,
0.06690506680841815,
0.1341444429167175,
0.243767436244511,
0.34200430365674417,
0.06,
-1.0,
0.9999999999999999,
0.14285714285714285,
],
[
0.7521367521367522,
0.7687019758973853,
0.7752808988764045,
0.7521367521367522,
0.7521367521367522,
0.9999999999999999,
0.1176470588235294,
-1.0,
0.12195121951219515,
0.061666706936960886,
0.24633981087680482,
0.3327359731569215,
0.5911185074290938,
0.04,
0.0018296973961998584,
0.9999999999999999,
-1.0,
],
[
0.7606837606837608,
0.7858384383301644,
0.797752808988764,
0.7606837606837608,
0.7606837606837608,
0.9999999999999999,
-1.0,
-1.0,
0.1829268292682927,
0.11659699905767515,
0.28536428668900904,
0.5119440260804912,
0.8622284211854495,
0.045,
0.001337086558761435,
0.9999999999999999,
-1.0,
],
[
0.7692307692307694,
0.7824301939161817,
0.7865168539325842,
0.7692307692307694,
0.7692307692307694,
0.9999999999999999,
-1.0,
-1.0,
0.2439024390243902,
0.09646024113852172,
0.3756083870047315,
0.4725436740192808,
0.7324707832177849,
0.05500000000000001,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.7777777777777779,
0.8062164745419109,
0.8202247191011236,
0.7777777777777779,
0.7777777777777779,
0.9999999999999999,
-1.0,
-1.0,
0.2073170731707317,
0.11115567690337544,
0.4634134575821911,
0.3535800303764005,
0.7502037587087667,
0.06,
0.00154820548909219,
0.9999999999999999,
-1.0,
],
[
0.7863247863247864,
0.8027163912065933,
0.8089887640449438,
0.7863247863247864,
0.7863247863247864,
0.9999999999999999,
-1.0,
-1.0,
0.201219512195122,
0.11461570058230844,
0.49999890365613264,
0.22851568705952632,
0.7278670299653185,
0.75,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.7948717948717949,
0.826526481923039,
0.8426966292134831,
0.7948717948717949,
0.7948717948717949,
0.9999999999999999,
-1.0,
-1.0,
0.17682926829268295,
0.10304201802498372,
0.4829256954882932,
0.22851568705952632,
0.5962337888207231,
0.8,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8034188034188036,
0.8231250403921182,
0.8314606741573034,
0.8034188034188036,
0.8034188034188036,
0.9999999999999999,
-1.0,
-1.0,
0.1829268292682927,
0.10050982192475896,
0.3341448814542644,
0.3185010072509358,
0.4903474640139954,
0.65,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8119658119658121,
0.8367308065158015,
0.848314606741573,
0.8119658119658121,
0.8119658119658121,
0.9999999999999999,
-1.0,
-1.0,
0.1829268292682927,
0.10136516297388068,
0.3292668219777389,
0.3370573020926671,
0.5761136820136477,
0.65,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8205128205128206,
0.8367308065158015,
0.8426966292134831,
0.8205128205128206,
0.8205128205128206,
0.9999999999999999,
-1.0,
-1.0,
0.1829268292682927,
0.11133930944499479,
0.3609742085751549,
0.31646744069293786,
0.16689117068329928,
0.4,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8290598290598292,
0.8503365726394846,
0.8595505617977528,
0.8290598290598292,
0.8290598290598292,
0.9999999999999999,
-1.0,
-1.0,
0.1829268292682927,
0.11538889023123203,
0.3682912977899432,
0.48576185664626753,
0.1992879528302852,
0.6,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8376068376068377,
0.8537380141704054,
0.8595505617977528,
0.8376068376068377,
0.8376068376068377,
0.9999999999999999,
-1.0,
-1.0,
0.1829268292682927,
0.12207214825911519,
0.3292668219777389,
0.28443876740447005,
-1.0,
0.6,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8461538461538463,
0.8707452218250095,
0.8820224719101123,
0.8461538461538463,
0.8461538461538463,
0.9999999999999999,
-1.0,
-1.0,
0.1829268292682927,
0.12593809650373308,
-1.0,
-1.0,
-1.0,
0.5,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8547008547008548,
0.8741466633559303,
0.8820224719101123,
0.8547008547008548,
0.8547008547008548,
0.9999999999999999,
-1.0,
-1.0,
0.1829268292682927,
0.12980404474835092,
-1.0,
-1.0,
-1.0,
0.15000000000000002,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8632478632478634,
0.8775481048868511,
0.8820224719101123,
0.8632478632478634,
0.8632478632478634,
0.9999999999999999,
-1.0,
-1.0,
0.1829268292682927,
0.1331867494623916,
-1.0,
-1.0,
-1.0,
0.35,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8717948717948719,
0.8877524294796135,
0.8932584269662921,
0.8717948717948719,
0.8717948717948719,
0.9999999999999999,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
1.0000000000000002,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8803418803418804,
0.8843509879486927,
0.8820224719101123,
0.8803418803418804,
0.8803418803418804,
0.9999999999999999,
0.1764705882352941,
-1.0,
-1.0,
-1.0,
0.4414621899378262,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8888888888888891,
0.8877524294796135,
0.8820224719101123,
0.8888888888888891,
0.8888888888888891,
0.9999999999999999,
0.23529411764705882,
-1.0,
-1.0,
-1.0,
0.9512194052347446,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.8974358974358976,
0.9013581956032967,
0.898876404494382,
0.8974358974358976,
0.8974358974358976,
0.9999999999999999,
0.2941176470588235,
-1.0,
-1.0,
-1.0,
0.8536582157042338,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.9059829059829061,
0.894555312541455,
0.8820224719101123,
0.9059829059829061,
0.9059829059829061,
0.9999999999999999,
0.3529411764705882,
-1.0,
-1.0,
-1.0,
0.9024388104694893,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.9145299145299146,
0.9047596371342175,
0.8932584269662921,
0.9145299145299146,
0.9145299145299146,
0.9999999999999999,
0.4117647058823529,
-1.0,
-1.0,
-1.0,
1.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.9230769230769232,
0.9081610786651383,
0.8932584269662921,
0.9230769230769232,
0.9230769230769232,
0.9999999999999999,
0.47058823529411764,
-1.0,
-1.0,
-1.0,
0.8536582157042338,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.9316239316239318,
0.9183654032579007,
0.9044943820224719,
0.9316239316239318,
0.9316239316239318,
0.9999999999999999,
0.5294117647058824,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.9401709401709403,
0.9217668447888215,
0.9044943820224719,
0.9401709401709403,
0.9401709401709403,
0.9999999999999999,
0.5882352941176471,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.9487179487179489,
0.965985584690792,
0.9719101123595505,
0.9487179487179489,
0.9487179487179489,
0.9999999999999999,
0.6470588235294117,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
-1.0,
],
[
0.9572649572649574,
0.9625841431598712,
0.9606741573033708,
0.9572649572649574,
0.9572649572649574,
0.9999999999999999,
0.7058823529411764,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
0.2857142857142857,
],
[
0.9658119658119659,
0.9795913508144752,
0.9831460674157303,
0.9658119658119659,
0.9658119658119659,
0.9999999999999999,
0.7647058823529411,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
0.42857142857142855,
],
[
0.9743589743589745,
0.9761899092835544,
0.9719101123595505,
0.9743589743589745,
0.9743589743589745,
0.9999999999999999,
0.8235294117647058,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
0.5714285714285714,
],
[
0.9829059829059831,
0.9897956754072376,
0.9887640449438202,
0.9829059829059831,
0.9829059829059831,
0.9999999999999999,
0.8823529411764706,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
0.7142857142857142,
],
[
0.9914529914529915,
1.0,
1.0,
0.9914529914529915,
0.9914529914529915,
0.9999999999999999,
0.9411764705882353,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
0.8571428571428572,
],
[
1.0000000000000002,
0.9965985584690792,
0.9887640449438202,
1.0000000000000002,
1.0000000000000002,
0.9999999999999999,
1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
-1.0,
0.9999999999999999,
1.0,
],
]
@torch.jit.script
def gaussian(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Compute the Gaussian probability density of ``x``.

    Evaluates ``exp(-0.5 * ((x - mean) / std) ** 2) / (sqrt(2 * pi) * std)``
    element-wise, broadcasting ``x``, ``mean`` and ``std`` against each other.
    """
    # Double-precision value of math.pi, written as a literal so the scripted
    # function does not depend on any module-level import.
    pi = 3.141592653589793
    a = (2 * pi) ** 0.5
    output: torch.Tensor = torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
    return output
class GaussianLayer(nn.Module):
    """Gaussian pairwise positional embedding layer.

    Each input value (e.g. a pairwise distance) is first transformed by a
    learnable per-edge-type affine map (``mul * x + bias``), then expanded
    into ``k`` Gaussian basis functions with learnable means and standard
    deviations.
    """

    def __init__(self, k: int = 128, edge_types: int = 1024):
        """Create the layer.

        Args:
            k: number of Gaussian basis functions per input value.
            edge_types: number of distinct edge-type indices (size of the
                ``mul``/``bias`` lookup tables).
        """
        super().__init__()
        self.k = k
        self.means = nn.Embedding(1, k)
        self.stds = nn.Embedding(1, k)
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        nn.init.uniform_(self.means.weight, 0, 3)
        nn.init.uniform_(self.stds.weight, 0, 3)
        # Start as the identity affine map: mul=1, bias=0.
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x: torch.Tensor, edge_types: torch.Tensor) -> torch.Tensor:
        """Forward pass to compute the Gaussian pos. embeddings.

        Args:
            x: input values (e.g. pairwise distances) of any shape ``S``.
            edge_types: integer index tensor of shape ``S`` selecting the
                per-edge affine parameters.

        Returns:
            Tensor of shape ``S + (k,)`` in the dtype of ``self.means.weight``.
        """
        mul = self.mul(edge_types)  # S + (1,)
        bias = self.bias(edge_types)  # S + (1,)
        x = mul * x.unsqueeze(-1) + bias
        # Works for any number of leading dimensions, not only 3-D inputs.
        x = x.expand(*x.shape[:-1], self.k)
        mean = self.means.weight.float().view(-1)
        # abs() + epsilon keeps the standard deviation strictly positive.
        std = self.stds.weight.float().view(-1).abs() + 1e-5
        # Evaluate in float32 for stability, then cast back to the parameter dtype.
        output: torch.Tensor = gaussian(x.float(), mean, std).type_as(self.means.weight)
        return output
class ParallelBlock(nn.Module):
    """Parallel transformer block (MLP & Attention in parallel).

    Based on:
    'Scaling Vision Transformers to 22 Billion Parameters' - https://arxiv.org/abs/2302.05442
    Adapted from TIMM implementation.

    One fused input projection produces the MLP hidden activations together
    with the attention queries, keys and values. One fused output projection
    maps the concatenated MLP activations and attention output to ``2 * dim``
    features. These are split into an MLP part and an attention part and
    summed with the residual. The pairwise positional embeddings are projected
    to a single additive attention bias that is broadcast over all heads.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: int = 4,
        dropout: float = 0.0,
        k: int = 128,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # NOTE(review): not used in forward; scaled_dot_product_attention applies
        # its own default scale of 1/sqrt(head_dim), which equals this value.
        self.scale = self.head_dim**-0.5
        self.mlp_hidden_dim = int(mlp_ratio * dim)
        self.proj_drop = nn.Dropout(dropout)
        # NOTE(review): constructed but never applied in forward
        # (no dropout_p is passed to scaled_dot_product_attention).
        self.attn_drop = nn.Dropout(dropout)
        # Stored only; forward does not read this flag.
        self.gradient_checkpointing = gradient_checkpointing
        self.in_proj_in_dim = dim
        self.in_proj_out_dim = self.mlp_hidden_dim + 3 * dim
        self.out_proj_in_dim = self.mlp_hidden_dim + dim
        self.out_proj_out_dim = 2 * dim
        self.in_split = [self.mlp_hidden_dim] + [dim] * 3
        self.out_split = [dim] * 2
        self.in_norm = nn.LayerNorm(dim)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.in_proj = nn.Linear(self.in_proj_in_dim, self.in_proj_out_dim, bias=False)
        self.act_fn = nn.GELU()
        self.out_proj = nn.Linear(
            self.out_proj_in_dim, self.out_proj_out_dim, bias=False
        )
        # Reduces the k Gaussian features of each (query, key) pair to one bias.
        self.gaussian_proj = nn.Linear(k, 1)
        self.pos_embed_ff_norm = nn.LayerNorm(k)

    def forward(
        self,
        x: torch.Tensor,
        pos_embed: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for the parallel block.

        Args:
            x: token features of shape (b, n, dim).
            pos_embed: pairwise embeddings with last axis of size ``k``;
                presumably (b, n, n, k) as produced by the encoder.
            attention_mask: optional additive mask broadcastable to
                (b, num_heads, n, n). ``None`` means no masking; only the
                positional bias is applied.

        Returns:
            ``(x, pos_embed)``: the updated features (b, n, dim) and the
            unchanged ``pos_embed``.
        """
        b, n, c = x.shape
        res = x
        # Combined MLP fc1 & qkv projections
        x = self.in_proj(self.in_norm(x))
        x, q, k, v = torch.split(x, self.in_split, dim=-1)
        x = self.act_fn(x)
        x = self.proj_drop(x)
        # Dot product attention (per-head QK layer norm)
        q = self.q_norm(q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
        k = self.k_norm(k.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
        v = v.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        # Pairwise positional bias with a singleton head axis, shared by all heads.
        attn_bias = self.gaussian_proj(self.pos_embed_ff_norm(pos_embed)).permute(
            0, 3, 1, 2
        )
        # The mask is optional; only add it when one was supplied.
        if attention_mask is not None:
            attn_bias = attention_mask + attn_bias
        x_attn = (
            f.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_bias,
                is_causal=False,
            )
            .transpose(1, 2)
            .reshape(b, n, c)
        )
        # Combined MLP fc2 & attn_output projection
        x_mlp, x_attn = self.out_proj(torch.cat([x, x_attn], dim=-1)).split(
            self.out_split, dim=-1
        )
        # Residual connections
        x = x_mlp + x_attn + res
        del x_mlp, x_attn, res
        return x, pos_embed
class AtomformerEncoder(nn.Module):
    """Atomformer encoder.

    Builds token features from a learned token embedding plus a linear
    projection of a frozen 17-feature metadata table. Builds pairwise
    features from Gaussian expansions of inter-atomic distances. Both are fed
    through a stack of ``ParallelBlock`` layers.

    A graph token (id 122) is prepended to every sequence. Its coordinate is
    the mean of the input coordinates along dim 1.
    NOTE(review): that mean runs over every position, including padded ones.
    """

    def __init__(self, config: AtomformerConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.dim = config.dim
        self.num_heads = config.num_heads
        self.depth = config.depth
        self.mlp_ratio = config.mlp_ratio
        self.dropout = config.dropout
        self.k = config.k
        self.gradient_checkpointing = config.gradient_checkpointing

        # Frozen per-token metadata (17 features). Row 0 and the last four rows
        # are not covered by ATOM_METADATA and stay at -1; this assumes
        # vocab_size == len(ATOM_METADATA) + 5.
        self.metadata_vocab = nn.Embedding(self.vocab_size, 17)
        self.metadata_vocab.weight.requires_grad = False
        self.metadata_vocab.weight.fill_(-1)
        self.metadata_vocab.weight[1:-4] = torch.tensor(
            ATOM_METADATA, dtype=torch.float32
        )
        self.embed_metadata = nn.Linear(17, self.dim)

        # Edge type of pair (i, j) is id_i * vocab_size + id_j; (vocab_size + 1) ** 2
        # entries comfortably cover every such index.
        num_pair_types = (self.vocab_size + 1) ** 2
        self.gaussian_embed = GaussianLayer(k=self.k, edge_types=num_pair_types)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.dim)
        nn.init.normal_(self.embed_tokens.weight, std=0.02)

        self.blocks = nn.ModuleList(
            [
                ParallelBlock(
                    self.dim,
                    self.num_heads,
                    self.mlp_ratio,
                    self.dropout,
                    self.k,
                    self.gradient_checkpointing,
                )
                for _ in range(self.depth)
            ]
        )

    def _expand_mask(
        self,
        mask: torch.Tensor,
        dtype: torch.dtype,
        device: torch.device,
        tgt_len: Optional[int] = None,
    ) -> torch.Tensor:
        """Turn a ``[bsz, src_len]`` keep-mask into an additive attention mask.

        The result has shape ``[bsz, 1, tgt_len, src_len]``. It holds 0 where
        ``mask`` is truthy and ``torch.finfo(dtype).min`` where it is 0.
        ``tgt_len`` defaults to ``src_len``.
        """
        batch, source_len = mask.size()
        if tgt_len is None:
            tgt_len = source_len
        broadcast = mask[:, None, None, :].expand(batch, 1, tgt_len, source_len)
        keep = broadcast.to(dtype)
        blocked: torch.Tensor = 1.0 - keep
        fill_value = torch.finfo(dtype).min
        additive = blocked.masked_fill(blocked.to(torch.bool), fill_value)
        return additive.to(device)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of structures.

        Args:
            input_ids: integer token ids; assumed (batch, num_atoms).
            coords: atom coordinates; assumed (batch, num_atoms, 3).
            attention_mask: optional (batch, num_atoms) keep-mask.

        Returns:
            ``(token_states, pair_states)``. Both include the prepended graph
            token at index 0 along the sequence axes.
        """
        # Graph token is placed at the mean position (all positions along dim 1).
        centroid = torch.sum(coords, dim=1, keepdim=True) / coords.shape[1]
        coords = torch.cat([centroid, coords], dim=1)
        distances = torch.cdist(coords, coords, p=2)  # [B, N, N]

        graph_ids = torch.full(
            (input_ids.size(0), 1), 122, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([graph_ids, input_ids], dim=1)

        pair_types = input_ids.unsqueeze(-1) * self.vocab_size + input_ids.unsqueeze(-2)
        pair_states = self.gaussian_embed(distances, pair_types)  # [B, N, N, K]

        token_states = self.embed_tokens(input_ids) + self.embed_metadata(
            self.metadata_vocab(input_ids)
        )  # [B, N, C]

        if attention_mask is not None:
            # The graph token is always attendable.
            graph_keep = torch.ones(
                attention_mask.size(0),
                1,
                dtype=torch.bool,
                device=attention_mask.device,
            )
            keep_mask = torch.cat([graph_keep, attention_mask.bool()], dim=1)
            attention_mask = self._expand_mask(
                keep_mask, token_states.dtype, token_states.device
            )

        for layer in self.blocks:
            token_states, pair_states = layer(token_states, pair_states, attention_mask)
        return token_states, pair_states
class AtomformerPreTrainedModel(PreTrainedModel):  # type: ignore
    """Common base for every Atomformer model.

    Binds ``AtomformerConfig`` to the Hugging Face ``PreTrainedModel``
    machinery. Declares gradient-checkpointing support and keeps each
    ``ParallelBlock`` on a single device when sharding.
    """

    config_class = AtomformerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ParallelBlock"]

    def _set_gradient_checkpointing(
        self, module: nn.Module, value: bool = False
    ) -> None:
        """Set ``gradient_checkpointing`` on ``module`` if it is an encoder.

        NOTE(review): in this file neither the encoder nor ``ParallelBlock``
        reads this flag during ``forward``.
        """
        if not isinstance(module, AtomformerEncoder):
            return
        module.gradient_checkpointing = value
class AtomformerModel(AtomformerPreTrainedModel):
    """Atomformer model for atom modeling (bare encoder, no task head)."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return per-atom hidden states aligned with ``input_ids``.

        The encoder prepends a graph token at index 0 of the sequence. It is
        removed here, so the output has the same sequence length as
        ``input_ids`` and row ``i`` belongs to atom ``i``. This matches the
        ``[:, 1:]`` convention used by the per-atom heads in this file.
        """
        # The encoder returns (token_states, pair_states); only token states are exposed.
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        return hidden_states[:, 1:]
class AtomformerForMaskedAM(AtomformerPreTrainedModel):
    """Atomformer with an atom modeling head on top for masked atom modeling."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.am_head = nn.Linear(config.dim, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the masked atom modeling model.

        Args:
            input_ids: atom token ids; assumed (batch, num_atoms).
            coords: atom coordinates passed to the encoder.
            labels: optional target token ids, presumably aligned with
                ``input_ids``.
            fixed: unused; accepted for interface compatibility.
            attention_mask: optional keep-mask aligned with ``input_ids``.

        Returns:
            ``(loss, logits)``. ``logits`` has one row of ``vocab_size``
            scores per input atom. ``loss`` is the cross-entropy over the
            flattened positions, or ``None`` when ``labels`` is omitted.
        """
        # The encoder returns a (token_states, pair_states) tuple; index 0 of the
        # token states is the prepended graph token, which has no label.
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        logits = self.am_head(hidden_states[:, 1:])
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Flatten locally; the returned logits keep their per-atom layout.
            loss = loss_fct(
                logits.reshape(-1, self.config.vocab_size),
                labels.to(logits.device).reshape(-1),
            )
        return loss, logits
class AtomformerForCoordinateAM(AtomformerPreTrainedModel):
    """Atomformer with an atom coordinate head on top for coordinate denoising."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the coordinate atom modeling model.

        Args:
            input_ids: atom token ids; assumed (batch, num_atoms).
            coords: input (noisy) coordinates passed to the encoder.
            labels_coords: optional target coordinates, presumably
                (batch, num_atoms, 3).
            fixed: unused; accepted for interface compatibility.
            attention_mask: optional keep-mask aligned with ``input_ids``.

        Returns:
            ``(loss, coords_pred)``. ``coords_pred`` holds 3 values per input
            atom. ``loss`` is the L1 loss, or ``None`` when no labels are given.
        """
        # Unpack the (token_states, pair_states) tuple and drop the graph token
        # at index 0 so predictions align with the input atoms.
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        coords_pred = self.coords_head(hidden_states[:, 1:])
        loss = None
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(coords_pred, labels_coords)
        return loss, coords_pred
class InitialStructure2RelaxedStructure(AtomformerPreTrainedModel):
    """Atomformer with a coordinate head on top for relaxed structure prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call.

        Initial structure to relaxed structure model.

        Args:
            input_ids: atom token ids; assumed (batch, num_atoms).
            coords: initial-structure coordinates passed to the encoder.
            labels_coords: optional relaxed-structure coordinates,
                presumably (batch, num_atoms, 3).
            fixed: unused; accepted for interface compatibility.
            attention_mask: optional keep-mask aligned with ``input_ids``.

        Returns:
            ``(loss, coords_pred)``. ``coords_pred`` holds 3 values per input
            atom. ``loss`` is the L1 loss, or ``None`` when no labels are given.
        """
        # Unpack the (token_states, pair_states) tuple and drop the graph token
        # at index 0 so predictions align with the input atoms.
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        coords_pred = self.coords_head(hidden_states[:, 1:])
        loss = None
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(coords_pred, labels_coords)
        return loss, coords_pred
class InitialStructure2RelaxedEnergy(AtomformerPreTrainedModel):
    """Atomformer with an energy head on top for relaxed energy prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_energy: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the relaxed energy prediction model.

        The energy is read from the graph token (sequence index 0).

        Args:
            input_ids: atom token ids; assumed (batch, num_atoms).
            coords: atom coordinates passed to the encoder.
            labels_energy: optional target energies, presumably (batch,).
            fixed: unused; accepted for interface compatibility.
            attention_mask: optional keep-mask aligned with ``input_ids``.

        Returns:
            ``(loss, energy)``. ``energy`` has one value per structure.
            ``loss`` is the L1 loss, or ``None`` when no labels are given.
        """
        # The encoder returns a (token_states, pair_states) tuple.
        hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        energy = self.energy_head(self.energy_norm(hidden_states[:, 0])).squeeze(-1)
        loss = None
        if labels_energy is not None:
            loss_fct = nn.L1Loss()
            loss = loss_fct(energy, labels_energy.to(energy.device))
        return loss, energy
class InitialStructure2RelaxedStructureAndEnergy(AtomformerPreTrainedModel):
    """Atomformer with a coordinate and an energy head."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.energy_head = nn.Linear(config.dim, 1, bias=False)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward function call for the relaxed structure and energy model.

        Args:
            input_ids: atom token ids; assumed (batch, num_atoms).
            coords: atom coordinates passed to the encoder.
            labels_coords: optional target coordinates for the atom positions.
            forces: unused; accepted for interface compatibility.
            total_energy: unused; accepted for interface compatibility.
            formation_energy: optional per-structure energy targets.
            has_formation_energy: optional boolean selector of structures
                whose ``formation_energy`` is valid. ``None`` means all.
            attention_mask: optional keep-mask aligned with ``input_ids``.

        Returns:
            ``(loss, (formation_energy_pred, coords_pred))``. ``loss`` is a
            0-dim tensor summing the L1 terms for whichever targets were
            supplied; it is 0 when none were.
        """
        atom_hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        # Energy comes from the graph token (index 0) via the head defined in
        # __init__ (``energy_head``; there is no ``formation_energy_head`` here).
        formation_energy_pred = self.energy_head(
            self.energy_norm(atom_hidden_states[:, 0])
        ).squeeze(-1)
        coords_pred = self.coords_head(atom_hidden_states[:, 1:])

        # 0-dim zero accumulator (torch.Tensor(0) would be an *empty* tensor,
        # which turns every sum into an empty tensor as well).
        loss = torch.zeros((), device=coords.device)
        if formation_energy is not None:
            energy_pred, energy_target = formation_energy_pred, formation_energy
            if has_formation_energy is not None:
                energy_pred = energy_pred[has_formation_energy]
                energy_target = energy_target[has_formation_energy]
            loss = loss + nn.L1Loss()(energy_pred, energy_target)
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss = loss + nn.L1Loss()(coords_pred, labels_coords)
        return loss, (formation_energy_pred, coords_pred)
class Structure2Energy(AtomformerPreTrainedModel):
    """Atomformer with a formation-energy head on top for energy prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward function call for the structure to energy model.

        Args:
            input_ids: atom token ids; assumed (batch, num_atoms).
            coords: atom coordinates passed to the encoder.
            forces: unused; accepted for interface compatibility.
            total_energy: unused; accepted for interface compatibility.
            formation_energy: optional per-structure energy targets.
            has_formation_energy: optional boolean selector of structures
                whose ``formation_energy`` is valid. ``None`` means all.
            attention_mask: optional keep-mask aligned with ``input_ids``.

        Returns:
            ``(loss, (formation_energy_pred, mask))``. ``loss`` is a 0-dim
            tensor (0 when no targets are given). ``mask`` is
            ``attention_mask`` as bool, or ``None``.
        """
        atom_hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        # Energy is read from the graph token at sequence index 0.
        formation_energy_pred: torch.Tensor = self.formation_energy_head(
            self.energy_norm(atom_hidden_states[:, 0])
        ).squeeze(-1)
        # 0-dim zero; torch.Tensor(0) would create an empty shape-(0,) tensor.
        loss = torch.zeros((), device=coords.device)
        if formation_energy is not None:
            energy_pred, energy_target = formation_energy_pred, formation_energy
            if has_formation_energy is not None:
                energy_pred = energy_pred[has_formation_energy]
                energy_target = energy_target[has_formation_energy]
            loss = nn.L1Loss()(energy_pred, energy_target)
        return loss, (
            formation_energy_pred,
            attention_mask.bool() if attention_mask is not None else None,
        )
class Structure2Forces(AtomformerPreTrainedModel):
    """Atomformer with a forces head on top for forces prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.force_norm = nn.LayerNorm(config.dim)
        self.force_head = nn.Linear(config.dim, 3)
        # NOTE(review): the energy modules below are not used in forward; they
        # are kept so the parameter names stay compatible with saved checkpoints.
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward function call for the structure to forces model.

        Args:
            input_ids: atom token ids; assumed (batch, num_atoms).
            coords: atom coordinates passed to the encoder.
            forces: optional per-atom force targets, presumably
                (batch, num_atoms, 3).
            total_energy, formation_energy, has_formation_energy: unused;
                accepted for interface compatibility.
            attention_mask: optional keep-mask aligned with ``input_ids``.
                When given, only its truthy positions enter the loss.

        Returns:
            ``(loss, (forces_pred, mask))``. ``loss`` is a 0-dim tensor (0
            when ``forces`` is omitted). ``mask`` is ``attention_mask`` as
            bool, or ``None``.
        """
        atom_hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        mask = attention_mask.bool() if attention_mask is not None else None
        # Drop the graph token (index 0) so predictions align with input atoms.
        forces_pred: torch.Tensor = self.force_head(
            self.force_norm(atom_hidden_states[:, 1:])
        )
        # 0-dim zero; torch.Tensor(0) would create an empty shape-(0,) tensor.
        loss = torch.zeros((), device=coords.device)
        if forces is not None:
            force_pred, force_target = forces_pred, forces
            if mask is not None:
                force_pred, force_target = force_pred[mask], force_target[mask]
            loss = nn.L1Loss()(force_pred, force_target)
        return loss, (forces_pred, mask)
class Structure2EnergyAndForces(AtomformerPreTrainedModel):
    """Atomformer with an energy and forces head for energy and forces prediction."""

    def __init__(self, config: AtomformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = AtomformerEncoder(config)
        self.force_norm = nn.LayerNorm(config.dim)
        self.force_head = nn.Linear(config.dim, 3)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]:
        """Forward function call for the structure to energy and forces model.

        Args:
            input_ids: atom token ids; assumed (batch, num_atoms).
            coords: atom coordinates passed to the encoder.
            forces: optional per-atom force targets, presumably
                (batch, num_atoms, 3).
            total_energy: unused; accepted for interface compatibility.
            formation_energy: optional per-structure energy targets.
            has_formation_energy: optional boolean selector of structures
                whose ``formation_energy`` is valid. ``None`` means all.
            attention_mask: optional keep-mask aligned with ``input_ids``.
                When given, only its truthy positions enter the force loss.

        Returns:
            ``(loss, (formation_energy_pred, forces_pred, mask))``. ``loss``
            is a 0-dim tensor summing the L1 terms for whichever targets were
            supplied (0 when none were). ``mask`` is ``attention_mask`` as
            bool, or ``None``.
        """
        atom_hidden_states, _ = self.encoder(input_ids, coords, attention_mask)
        # Energy from the graph token (index 0); forces from the atom tokens.
        formation_energy_pred: torch.Tensor = self.formation_energy_head(
            self.energy_norm(atom_hidden_states[:, 0])
        ).squeeze(-1)
        mask = attention_mask.bool() if attention_mask is not None else None
        forces_pred: torch.Tensor = self.force_head(
            self.force_norm(atom_hidden_states[:, 1:])
        )

        # 0-dim zero accumulator; torch.Tensor(0) would be an empty tensor and
        # make the summed loss empty too.
        loss = torch.zeros((), device=coords.device)
        if formation_energy is not None:
            energy_pred, energy_target = formation_energy_pred, formation_energy
            if has_formation_energy is not None:
                energy_pred = energy_pred[has_formation_energy]
                energy_target = energy_target[has_formation_energy]
            loss = loss + nn.L1Loss()(energy_pred, energy_target)
        if forces is not None:
            force_pred, force_target = forces_pred, forces
            if mask is not None:
                force_pred, force_target = force_pred[mask], force_target[mask]
            loss = loss + nn.L1Loss()(force_pred, force_target)
        return loss, (formation_energy_pred, forces_pred, mask)