# Python dependencies, in pip requirements-file format.
# Install with: pip install -r <this file>

# JAX / Flax / Optax.
# NOTE(review): these are lower bounds only, so a fresh install resolves the
# newest releases; jax and jaxlib must be mutually compatible - consider
# pinning both (and flax/optax) together for reproducible installs.
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.4
optax>=0.0.8

# The `+cpu` local-version wheels below are resolved from PyTorch's
# find-links page. `-f/--find-links` is a global option for the whole file,
# so it only needs to appear once.
-f https://download.pytorch.org/whl/torch_stable.html
# torchvision releases are built against specific torch releases; bump these
# two pins together.
torch==1.9.0+cpu
torchvision==0.10.0+cpu

comet_ml==3.12.2
python-dotenv==0.18.0
# NOTE(review): tqdm and transformers are unpinned, unlike the exact pins
# above; consider pinning them for reproducible installs.
tqdm
transformers