cc
- configs/experiment/example.yaml +0 -41
- configs/experiment/miniagent-bert-mlp.yaml +36 -0
- configs/model/miniagent.yaml +3 -0
- configs/model/mnist.yaml +0 -25
- configs/train.yaml +5 -5
- notebooks/test_bert.ipynb +0 -244
- src/eval.py +3 -0
- src/models/components/__init__.py +0 -0
- src/models/components/simple_dense_net.py +0 -54
- src/models/miniagent_module.py +96 -0
- src/models/mlp_module.py +55 -0
- src/models/mnist_module.py +0 -217
- src/train.py +5 -1
- test_bert.ipynb +142 -0
configs/experiment/example.yaml
DELETED
@@ -1,41 +0,0 @@
-# @package _global_
-
-# to execute this experiment run:
-# python train.py experiment=example
-
-defaults:
-  - override /data: mnist
-  - override /model: mnist
-  - override /callbacks: default
-  - override /trainer: default
-
-# all parameters below will be merged with parameters from default configurations set above
-# this allows you to overwrite only specified parameters
-
-tags: ["mnist", "simple_dense_net"]
-
-seed: 12345
-
-trainer:
-  min_epochs: 10
-  max_epochs: 10
-  gradient_clip_val: 0.5
-
-model:
-  optimizer:
-    lr: 0.002
-  net:
-    lin1_size: 128
-    lin2_size: 256
-    lin3_size: 64
-  compile: false
-
-data:
-  batch_size: 64
-
-logger:
-  wandb:
-    tags: ${tags}
-    group: "mnist"
-  aim:
-    experiment: "mnist"
configs/experiment/miniagent-bert-mlp.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+defaults:
+  - override /data: mixed
+  - override /model: miniagent
+  - override /callbacks: default
+  - override /trainer: gpu
+
+seed: 42
+
+model:
+  lr: 0.0001
+  bert_model: bert-base-uncased
+
+  inst_proj_model:
+    _target_: src.models.mlp_module.MLPProjection
+    input_dim: 768
+    hidden_dim: 768
+    output_dim: 768
+
+  tool_proj_model:
+    _target_: src.models.mlp_module.MLPProjection
+    input_dim: 768
+    hidden_dim: 768
+    output_dim: 768
+
+  pred_model:
+    _target_: src.models.mlp_module.MLPPrediction
+    input_dim: 768
+    use_abs_diff: false
+    use_mult: false
+
+data:
+  bert_model: bert-base-uncased
+  batch_size: 32
+  tool_capacity: 16
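Presumably this experiment is launched the same way the deleted example config documented itself, i.e. `python train.py experiment=miniagent-bert-mlp` (the `# python train.py experiment=example` comment above gives the template's convention; the exact entry-point path is an assumption).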
configs/model/miniagent.yaml
ADDED
@@ -0,0 +1,3 @@
+_target_: src.models.miniagent_module.MiniAgentModule
+
+bert_model: bert-base-uncased
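Note that this model config only pins the target class and the BERT checkpoint; the remaining constructor arguments of MiniAgentModule (inst_proj_model, tool_proj_model, pred_model, lr) are supplied by the experiment config above, which Hydra merges on top of this file via its `# @package _global_` overrides.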
configs/model/mnist.yaml
DELETED
@@ -1,25 +0,0 @@
-_target_: src.models.mnist_module.MNISTLitModule
-
-optimizer:
-  _target_: torch.optim.Adam
-  _partial_: true
-  lr: 0.001
-  weight_decay: 0.0
-
-scheduler:
-  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
-  _partial_: true
-  mode: min
-  factor: 0.1
-  patience: 10
-
-net:
-  _target_: src.models.components.simple_dense_net.SimpleDenseNet
-  input_size: 784
-  lin1_size: 64
-  lin2_size: 128
-  lin3_size: 64
-  output_size: 10
-
-# compile model for faster training with pytorch 2.0
-compile: false
configs/train.yaml
CHANGED
@@ -4,11 +4,11 @@
 # order of defaults determines the order in which configs override each other
 defaults:
   - _self_
-  - data: mnist
-  - model: mnist
+  - data: mixed
+  - model: miniagent
   - callbacks: default
-  - logger: null
-  - trainer: default
+  - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+  - trainer: gpu
   - paths: default
   - extras: default
   - hydra: default

@@ -46,4 +46,4 @@ test: True
 ckpt_path: null
 
 # seed for random number generators in pytorch, numpy and python.random
-seed: null
+seed: 42
notebooks/test_bert.ipynb
DELETED
@@ -1,244 +0,0 @@
(The 244 deleted lines are the notebook's JSON source; kernel "swim", Python 3.12.6. Its code cells:)

-import torch
-from transformers import BertModel, BertTokenizer
-
-model_name = "bert-base-uncased"
-
-tokenizer = BertTokenizer.from_pretrained(model_name)
-model = BertModel.from_pretrained(model_name).cuda()

-input_text = "Here is some text to encode"
-input_ids = tokenizer.encode(input_text, add_special_tokens=True)
-input_ids = torch.tensor([input_ids]).cuda()
-
-print(input_ids)
-
-with torch.no_grad():
-    last_hidden_states = model(input_ids)  # Models outputs are now tuples
-
-last_hidden_states

(Stored outputs: the printed token ids tensor([[ 101, 2182, 2003, 2070, 3793, 2000, 4372, 16044, 102]], device='cuda:0') and a full BaseModelOutputWithPoolingAndCrossAttentions dump of last_hidden_state and the 768-dim pooler_output, omitted here.)
src/eval.py
CHANGED
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List, Tuple
 
+import torch
 import hydra
 import rootutils
 from lightning import LightningDataModule, LightningModule, Trainer

@@ -34,6 +35,8 @@ from src.utils import (
 
 log = RankedLogger(__name__, rank_zero_only=True)
 
+torch.set_float32_matmul_precision("medium")
+
 
 @task_wrapper
 def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
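For context on the new call: torch.set_float32_matmul_precision is the standard PyTorch knob for float32 matmul precision; "medium" permits faster, lower-precision internal kernels (e.g. TF32 or bfloat16-based) on GPUs that support them, trading a little accuracy for throughput. A minimal sketch of the knob in isolation:

import torch

# "highest" (the default) keeps exact float32 matmuls;
# "high" and "medium" allow progressively faster, lower-precision kernels
torch.set_float32_matmul_precision("medium")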
src/models/components/__init__.py
DELETED
File without changes
src/models/components/simple_dense_net.py
DELETED
@@ -1,54 +0,0 @@
-import torch
-from torch import nn
-
-
-class SimpleDenseNet(nn.Module):
-    """A simple fully-connected neural net for computing predictions."""
-
-    def __init__(
-        self,
-        input_size: int = 784,
-        lin1_size: int = 256,
-        lin2_size: int = 256,
-        lin3_size: int = 256,
-        output_size: int = 10,
-    ) -> None:
-        """Initialize a `SimpleDenseNet` module.
-
-        :param input_size: The number of input features.
-        :param lin1_size: The number of output features of the first linear layer.
-        :param lin2_size: The number of output features of the second linear layer.
-        :param lin3_size: The number of output features of the third linear layer.
-        :param output_size: The number of output features of the final linear layer.
-        """
-        super().__init__()
-
-        self.model = nn.Sequential(
-            nn.Linear(input_size, lin1_size),
-            nn.BatchNorm1d(lin1_size),
-            nn.ReLU(),
-            nn.Linear(lin1_size, lin2_size),
-            nn.BatchNorm1d(lin2_size),
-            nn.ReLU(),
-            nn.Linear(lin2_size, lin3_size),
-            nn.BatchNorm1d(lin3_size),
-            nn.ReLU(),
-            nn.Linear(lin3_size, output_size),
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Perform a single forward pass through the network.
-
-        :param x: The input tensor.
-        :return: A tensor of predictions.
-        """
-        batch_size, channels, width, height = x.size()
-
-        # (batch, 1, width, height) -> (batch, 1*width*height)
-        x = x.view(batch_size, -1)
-
-        return self.model(x)
-
-
-if __name__ == "__main__":
-    _ = SimpleDenseNet()
src/models/miniagent_module.py
ADDED
@@ -0,0 +1,96 @@
+from typing import Any, Dict, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from lightning import LightningModule
+from torchmetrics import MaxMetric, MeanMetric
+from torchmetrics.classification.accuracy import Accuracy
+
+from transformers import BertModel
+
+
+class MiniAgentModule(LightningModule):
+    def __init__(
+        self,
+        bert_model: str,
+        inst_proj_model: nn.Module,
+        tool_proj_model: nn.Module,
+        pred_model: nn.Module,
+        lr: float,
+    ) -> None:
+        super().__init__()
+
+        self.save_hyperparameters(
+            logger=False, ignore=["inst_proj_model", "tool_proj_model", "pred_model"]
+        )
+
+        self.bert_model = BertModel.from_pretrained(bert_model)
+        self.bert_model.eval()
+        self.bert_model.requires_grad_(False)
+
+        self.inst_proj_model = inst_proj_model
+        self.tool_proj_model = tool_proj_model
+        self.pred_model = pred_model
+
+        self.lr = lr
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        pass
+
+    def on_train_start(self) -> None:
+        pass
+
+    def training_step(
+        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+    ) -> torch.Tensor:
+        B = batch["inst_ids"].shape[0]
+
+        inst_ids = batch["inst_ids"]
+        inst_mask = batch["inst_mask"]
+        tool_ids = batch["tool_desc_ids"]
+        tool_mask = batch["tool_desc_mask"]
+
+        inst_z = self.bert_model(inst_ids, inst_mask, return_dict=False)[1]
+        tool_z = self.bert_model(tool_ids, tool_mask, return_dict=False)[1]
+
+        inst_emb = self.inst_proj_model(inst_z)
+        tool_emb = self.tool_proj_model(tool_z)
+
+        inst_emb_r = inst_emb.unsqueeze(0).repeat(B, 1, 1).view(B * B, -1)
+        tool_emb_r = tool_emb.unsqueeze(1).repeat(1, B, 1).view(B * B, -1)
+
+        pred = self.pred_model(inst_emb_r, tool_emb_r)  # [BxB, 1]
+        pred = pred.view(B, B)  # [B, B]
+
+        # matching (instruction, tool) pairs sit on the diagonal, so the target is the identity matrix
+        target = torch.eye(B, device=pred.device).float()
+
+        loss = F.binary_cross_entropy_with_logits(pred, target)
+
+        self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
+
+        return loss
+
+    def on_train_epoch_end(self) -> None:
+        pass
+
+    def validation_step(
+        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+    ) -> None:
+        pass
+
+    def on_validation_epoch_end(self) -> None:
+        pass
+
+    def test_step(
+        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+    ) -> None:
+        pass
+
+    def on_test_epoch_end(self) -> None:
+        pass
+
+    def configure_optimizers(self):
+        opt = torch.optim.AdamW(self.parameters(), lr=self.lr)
+        return opt
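The training step scores every instruction against every tool description in the batch: the unsqueeze/repeat/view sequence expands the two [B, D] embedding matrices into all B*B pairs, with row i carrying tool i and column j carrying instruction j, so the true pairs land on the diagonal, which is why the target is torch.eye(B). A minimal standalone sketch of just that construction (shapes and values are illustrative, not from the repo):

import torch

B, D = 4, 768  # illustrative batch size and embedding dim
inst_emb = torch.randn(B, D)
tool_emb = torch.randn(B, D)

# [1, B, D] -> [B, B, D] -> [B*B, D]: flat entry (i, j) holds instruction j
inst_r = inst_emb.unsqueeze(0).repeat(B, 1, 1).view(B * B, -1)
# [B, 1, D] -> [B, B, D] -> [B*B, D]: flat entry (i, j) holds tool i
tool_r = tool_emb.unsqueeze(1).repeat(1, B, 1).view(B * B, -1)

# after scores are reshaped to [B, B], position (i, j) pairs tool i with
# instruction j, so matching pairs sit where i == j
row, col = 2, 3
flat = row * B + col
assert torch.equal(inst_r[flat], inst_emb[col])
assert torch.equal(tool_r[flat], tool_emb[row])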
src/models/mlp_module.py
ADDED
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MLPProjection(nn.Module):
+    def __init__(self, input_dim, hidden_dim, output_dim):
+        super().__init__()
+        self.linear1 = nn.Linear(input_dim, hidden_dim)
+        self.linear2 = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x_output):
+        # only use first token ([CLS]) of each output
+        x = x_output
+
+        x = self.linear1(x)
+        x = F.silu(x)
+        x = self.linear2(x)
+
+        return x
+
+
+class MLPPrediction(nn.Module):
+    def __init__(self, input_dim, use_abs_diff=False, use_mult=False):
+        super().__init__()
+
+        self.use_abs_diff = use_abs_diff
+        self.use_mult = use_mult
+
+        real_input_dim = input_dim * (2 + int(use_abs_diff) + int(use_mult))
+
+        self.mlp = nn.Sequential(
+            nn.Linear(real_input_dim, 1024),
+            nn.SiLU(),
+            nn.Linear(1024, 512),
+            nn.SiLU(),
+            nn.Linear(512, 256),
+            nn.SiLU(),
+            nn.Linear(256, 1),
+        )
+
+    def forward(self, x1, x2):
+        x = torch.cat([x1, x2], dim=1)
+
+        if self.use_abs_diff:
+            x_diff = torch.abs(x1 - x2)
+            x = torch.cat([x, x_diff], dim=1)
+
+        if self.use_mult:
+            x_mult = x1 * x2
+            x = torch.cat([x, x_mult], dim=1)
+
+        x = self.mlp(x)
+
+        return x
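MLPPrediction's first-layer width follows from the feature construction: concatenating x1 and x2 contributes 2 x input_dim, and each enabled flag (abs-diff, elementwise product) appends one more input_dim, hence real_input_dim = input_dim * (2 + int(use_abs_diff) + int(use_mult)). A quick sanity check under illustrative settings (flags enabled here only for the arithmetic; the experiment config above sets both to false):

import torch
from src.models.mlp_module import MLPPrediction

head = MLPPrediction(input_dim=768, use_abs_diff=True, use_mult=True)

x1, x2 = torch.randn(8, 768), torch.randn(8, 768)
# features: [x1 | x2 | |x1 - x2| | x1 * x2] -> 4 * 768 = 3072
assert head.mlp[0].in_features == 3072
assert head(x1, x2).shape == (8, 1)  # one matching logit per pair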
src/models/mnist_module.py
DELETED
@@ -1,217 +0,0 @@
-from typing import Any, Dict, Tuple
-
-import torch
-from lightning import LightningModule
-from torchmetrics import MaxMetric, MeanMetric
-from torchmetrics.classification.accuracy import Accuracy
-
-
-class MNISTLitModule(LightningModule):
-    """Example of a `LightningModule` for MNIST classification.
-
-    A `LightningModule` implements 8 key methods:
-
-    ```python
-    def __init__(self):
-    # Define initialization code here.
-
-    def setup(self, stage):
-    # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
-    # This hook is called on every process when using DDP.
-
-    def training_step(self, batch, batch_idx):
-    # The complete training step.
-
-    def validation_step(self, batch, batch_idx):
-    # The complete validation step.
-
-    def test_step(self, batch, batch_idx):
-    # The complete test step.
-
-    def predict_step(self, batch, batch_idx):
-    # The complete predict step.
-
-    def configure_optimizers(self):
-    # Define and configure optimizers and LR schedulers.
-    ```
-
-    Docs:
-        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
-    """
-
-    def __init__(
-        self,
-        net: torch.nn.Module,
-        optimizer: torch.optim.Optimizer,
-        scheduler: torch.optim.lr_scheduler,
-        compile: bool,
-    ) -> None:
-        """Initialize a `MNISTLitModule`.
-
-        :param net: The model to train.
-        :param optimizer: The optimizer to use for training.
-        :param scheduler: The learning rate scheduler to use for training.
-        """
-        super().__init__()
-
-        # this line allows to access init params with 'self.hparams' attribute
-        # also ensures init params will be stored in ckpt
-        self.save_hyperparameters(logger=False)
-
-        self.net = net
-
-        # loss function
-        self.criterion = torch.nn.CrossEntropyLoss()
-
-        # metric objects for calculating and averaging accuracy across batches
-        self.train_acc = Accuracy(task="multiclass", num_classes=10)
-        self.val_acc = Accuracy(task="multiclass", num_classes=10)
-        self.test_acc = Accuracy(task="multiclass", num_classes=10)
-
-        # for averaging loss across batches
-        self.train_loss = MeanMetric()
-        self.val_loss = MeanMetric()
-        self.test_loss = MeanMetric()
-
-        # for tracking best so far validation accuracy
-        self.val_acc_best = MaxMetric()
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Perform a forward pass through the model `self.net`.
-
-        :param x: A tensor of images.
-        :return: A tensor of logits.
-        """
-        return self.net(x)
-
-    def on_train_start(self) -> None:
-        """Lightning hook that is called when training begins."""
-        # by default lightning executes validation step sanity checks before training starts,
-        # so it's worth to make sure validation metrics don't store results from these checks
-        self.val_loss.reset()
-        self.val_acc.reset()
-        self.val_acc_best.reset()
-
-    def model_step(
-        self, batch: Tuple[torch.Tensor, torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Perform a single model step on a batch of data.
-
-        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
-
-        :return: A tuple containing (in order):
-            - A tensor of losses.
-            - A tensor of predictions.
-            - A tensor of target labels.
-        """
-        x, y = batch
-        logits = self.forward(x)
-        loss = self.criterion(logits, y)
-        preds = torch.argmax(logits, dim=1)
-        return loss, preds, y
-
-    def training_step(
-        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
-    ) -> torch.Tensor:
-        """Perform a single training step on a batch of data from the training set.
-
-        :param batch: A batch of data (a tuple) containing the input tensor of images and target
-            labels.
-        :param batch_idx: The index of the current batch.
-        :return: A tensor of losses between model predictions and targets.
-        """
-        loss, preds, targets = self.model_step(batch)
-
-        # update and log metrics
-        self.train_loss(loss)
-        self.train_acc(preds, targets)
-        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
-        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
-
-        # return loss or backpropagation will fail
-        return loss
-
-    def on_train_epoch_end(self) -> None:
-        "Lightning hook that is called when a training epoch ends."
-        pass
-
-    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
-        """Perform a single validation step on a batch of data from the validation set.
-
-        :param batch: A batch of data (a tuple) containing the input tensor of images and target
-            labels.
-        :param batch_idx: The index of the current batch.
-        """
-        loss, preds, targets = self.model_step(batch)
-
-        # update and log metrics
-        self.val_loss(loss)
-        self.val_acc(preds, targets)
-        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
-        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
-
-    def on_validation_epoch_end(self) -> None:
-        "Lightning hook that is called when a validation epoch ends."
-        acc = self.val_acc.compute()  # get current val acc
-        self.val_acc_best(acc)  # update best so far val acc
-        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
-        # otherwise metric would be reset by lightning after each epoch
-        self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)
-
-    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
-        """Perform a single test step on a batch of data from the test set.
-
-        :param batch: A batch of data (a tuple) containing the input tensor of images and target
-            labels.
-        :param batch_idx: The index of the current batch.
-        """
-        loss, preds, targets = self.model_step(batch)
-
-        # update and log metrics
-        self.test_loss(loss)
-        self.test_acc(preds, targets)
-        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
-        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
-
-    def on_test_epoch_end(self) -> None:
-        """Lightning hook that is called when a test epoch ends."""
-        pass
-
-    def setup(self, stage: str) -> None:
-        """Lightning hook that is called at the beginning of fit (train + validate), validate,
-        test, or predict.
-
-        This is a good hook when you need to build models dynamically or adjust something about
-        them. This hook is called on every process when using DDP.
-
-        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
-        """
-        if self.hparams.compile and stage == "fit":
-            self.net = torch.compile(self.net)
-
-    def configure_optimizers(self) -> Dict[str, Any]:
-        """Choose what optimizers and learning-rate schedulers to use in your optimization.
-        Normally you'd need one. But in the case of GANs or similar you might have multiple.
-
-        Examples:
-            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
-
-        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
-        """
-        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
-        if self.hparams.scheduler is not None:
-            scheduler = self.hparams.scheduler(optimizer=optimizer)
-            return {
-                "optimizer": optimizer,
-                "lr_scheduler": {
-                    "scheduler": scheduler,
-                    "monitor": "val/loss",
-                    "interval": "epoch",
-                    "frequency": 1,
-                },
-            }
-        return {"optimizer": optimizer}
-
-
-if __name__ == "__main__":
-    _ = MNISTLitModule(None, None, None, None)
src/train.py
CHANGED
@@ -38,6 +38,8 @@ from src.utils import (
 
 log = RankedLogger(__name__, rank_zero_only=True)
 
+torch.set_float32_matmul_precision("medium")
+
 
 @task_wrapper
 def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

@@ -67,7 +69,9 @@
     logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
 
     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
-    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+    trainer: Trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=logger
+    )
 
     object_dict = {
         "cfg": cfg,
test_bert.ipynb
ADDED
@@ -0,0 +1,142 @@
(The 142 added lines are the notebook's JSON source; kernel "swim", Python 3.10.14. Cells and recorded outputs:)

+from src.data.mixed_datamodule import MixedDataModule

(The import emits a tqdm "IProgress not found. Please update jupyter and ipywidgets." warning from /home/qninh/miniconda3/lib/python3.10/site-packages/tqdm/auto.py.)

+datamodule = MixedDataModule(dataset_path="./data/mixed", batch_size=32, num_workers=4, bert_model="bert-base-uncased", tool_capacity=16)

+datamodule.setup(stage="fit")

+train_dataloader = datamodule.train_dataloader()

+# first sample
+batch = next(iter(train_dataloader))
+{
+    key: value.shape
+    for key, value in batch.items()
+}

Recorded output (train batch):
{'instruction': torch.Size([32, 128]),
 'instruction_mask': torch.Size([32, 128]),
 'tool_desc_emb': torch.Size([32, 128]),
 'tool_desc_mask': torch.Size([32, 128])}

+val_dataloader = datamodule.val_dataloader()

+# first sample
+batch = next(iter(val_dataloader))
+{
+    key: value.shape
+    for key, value in batch.items()
+}

Recorded output (val batch; note the extra tool_capacity = 16 dimension on the tool tensors):
{'instruction': torch.Size([32, 128]),
 'instruction_mask': torch.Size([32, 128]),
 'tool_desc_emb': torch.Size([32, 16, 128]),
 'tool_desc_mask': torch.Size([32, 16, 128])}