qninhdt committed
Commit 897fe06 · 1 Parent(s): 01377f7
configs/experiment/example.yaml DELETED
@@ -1,41 +0,0 @@
- # @package _global_
-
- # to execute this experiment run:
- # python train.py experiment=example
-
- defaults:
-   - override /data: mnist
-   - override /model: mnist
-   - override /callbacks: default
-   - override /trainer: default
-
- # all parameters below will be merged with parameters from default configurations set above
- # this allows you to overwrite only specified parameters
-
- tags: ["mnist", "simple_dense_net"]
-
- seed: 12345
-
- trainer:
-   min_epochs: 10
-   max_epochs: 10
-   gradient_clip_val: 0.5
-
- model:
-   optimizer:
-     lr: 0.002
-   net:
-     lin1_size: 128
-     lin2_size: 256
-     lin3_size: 64
-   compile: false
-
- data:
-   batch_size: 64
-
- logger:
-   wandb:
-     tags: ${tags}
-     group: "mnist"
-   aim:
-     experiment: "mnist"
configs/experiment/miniagent-bert-mlp.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ defaults:
+   - override /data: mixed
+   - override /model: miniagent
+   - override /callbacks: default
+   - override /trainer: gpu
+
+ seed: 42
+
+ model:
+   lr: 0.0001
+   bert_model: bert-base-uncased
+
+   inst_proj_model:
+     _target_: src.models.mlp_module.MLPProjection
+     input_dim: 768
+     hidden_dim: 768
+     output_dim: 768
+
+   tool_proj_model:
+     _target_: src.models.mlp_module.MLPProjection
+     input_dim: 768
+     hidden_dim: 768
+     output_dim: 768
+
+   pred_model:
+     _target_: src.models.mlp_module.MLPPrediction
+     input_dim: 768
+     use_abs_diff: false
+     use_mult: false
+
+ data:
+   bert_model: bert-base-uncased
+   batch_size: 32
+   tool_capacity: 16
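The deleted example config documented its own launch command (`# python train.py experiment=example`). Assuming the new experiment follows the same lightning-hydra-template convention, it would be launched with `python train.py experiment=miniagent-bert-mlp`, and any value above can be overridden from the command line, e.g. `python train.py experiment=miniagent-bert-mlp model.lr=0.0003 data.batch_size=64`.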
configs/model/miniagent.yaml ADDED
@@ -0,0 +1,3 @@
+ _target_: src.models.miniagent_module.MiniAgentModule
+
+ bert_model: bert-base-uncased
configs/model/mnist.yaml DELETED
@@ -1,25 +0,0 @@
- _target_: src.models.mnist_module.MNISTLitModule
-
- optimizer:
-   _target_: torch.optim.Adam
-   _partial_: true
-   lr: 0.001
-   weight_decay: 0.0
-
- scheduler:
-   _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
-   _partial_: true
-   mode: min
-   factor: 0.1
-   patience: 10
-
- net:
-   _target_: src.models.components.simple_dense_net.SimpleDenseNet
-   input_size: 784
-   lin1_size: 64
-   lin2_size: 128
-   lin3_size: 64
-   output_size: 10
-
- # compile model for faster training with pytorch 2.0
- compile: false
configs/train.yaml CHANGED
@@ -4,11 +4,11 @@
  # order of defaults determines the order in which configs override each other
  defaults:
    - _self_
-   - data: mnist
-   - model: mnist
+   - data: mixed
+   - model: miniagent
    - callbacks: default
-   - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
-   - trainer: default
+   - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+   - trainer: gpu
    - paths: default
    - extras: default
    - hydra: default

@@ -46,4 +46,4 @@ test: True
  ckpt_path: null

  # seed for random number generators in pytorch, numpy and python.random
- seed: null
+ seed: 42
notebooks/test_bert.ipynb DELETED
@@ -1,244 +0,0 @@
- {
-  "cells": [
-   {
-    "cell_type": "code",
-    "execution_count": 9,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "import torch\n",
-     "from transformers import BertModel, BertTokenizer\n",
-     "\n",
-     "model_name = \"bert-base-uncased\"\n",
-     "\n",
-     "tokenizer = BertTokenizer.from_pretrained(model_name)\n",
-     "model = BertModel.from_pretrained(model_name).cuda()"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 19,
-    "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "tensor([[ 101, 2182, 2003, 2070, 3793, 2000, 4372, 16044, 102]],\n",
-       "       device='cuda:0')\n"
-      ]
-     },
-     {
-      "data": {
-       "text/plain": [
-        "BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor(..., device='cuda:0'), pooler_output=tensor(..., device='cuda:0'), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)"
-        (161 lines of raw tensor values elided from this execution output)
-       ]
-      },
-      "execution_count": 19,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
-    "source": [
-     "input_text = \"Here is some text to encode\"\n",
-     "input_ids = tokenizer.encode(input_text, add_special_tokens=True)\n",
-     "input_ids = torch.tensor([input_ids]).cuda()\n",
-     "\n",
-     "print(input_ids)\n",
-     "\n",
-     "with torch.no_grad():\n",
-     "    last_hidden_states = model(input_ids)  # returns a BaseModelOutput, not a tuple\n",
-     "\n",
-     "last_hidden_states"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "metadata": {},
-    "outputs": [],
-    "source": []
-   }
-  ],
-  "metadata": {
-   "kernelspec": {
-    "display_name": "swim",
-    "language": "python",
-    "name": "python3"
-   },
-   "language_info": {
-    "codemirror_mode": {
-     "name": "ipython",
-     "version": 3
-    },
-    "file_extension": ".py",
-    "mimetype": "text/x-python",
-    "name": "python",
-    "nbconvert_exporter": "python",
-    "pygments_lexer": "ipython3",
-    "version": "3.12.6"
-   }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 2
- }
src/eval.py CHANGED
@@ -1,5 +1,6 @@
  from typing import Any, Dict, List, Tuple

+ import torch
  import hydra
  import rootutils
  from lightning import LightningDataModule, LightningModule, Trainer

@@ -34,6 +35,8 @@ from src.utils import (

  log = RankedLogger(__name__, rank_zero_only=True)

+ torch.set_float32_matmul_precision("medium")
+

  @task_wrapper
  def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
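For readers unfamiliar with the call added here (background on the PyTorch API, not something stated in the commit): `torch.set_float32_matmul_precision` trades float32 matmul accuracy for speed on GPUs with tensor cores.

```python
import torch

# "highest" (the default) keeps full-precision float32 matmuls;
# "high" and "medium" allow TF32 / reduced-precision tensor-core
# kernels, which are substantially faster on recent NVIDIA GPUs.
torch.set_float32_matmul_precision("medium")
```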
src/models/components/__init__.py DELETED
File without changes
src/models/components/simple_dense_net.py DELETED
@@ -1,54 +0,0 @@
- import torch
- from torch import nn
-
-
- class SimpleDenseNet(nn.Module):
-     """A simple fully-connected neural net for computing predictions."""
-
-     def __init__(
-         self,
-         input_size: int = 784,
-         lin1_size: int = 256,
-         lin2_size: int = 256,
-         lin3_size: int = 256,
-         output_size: int = 10,
-     ) -> None:
-         """Initialize a `SimpleDenseNet` module.
-
-         :param input_size: The number of input features.
-         :param lin1_size: The number of output features of the first linear layer.
-         :param lin2_size: The number of output features of the second linear layer.
-         :param lin3_size: The number of output features of the third linear layer.
-         :param output_size: The number of output features of the final linear layer.
-         """
-         super().__init__()
-
-         self.model = nn.Sequential(
-             nn.Linear(input_size, lin1_size),
-             nn.BatchNorm1d(lin1_size),
-             nn.ReLU(),
-             nn.Linear(lin1_size, lin2_size),
-             nn.BatchNorm1d(lin2_size),
-             nn.ReLU(),
-             nn.Linear(lin2_size, lin3_size),
-             nn.BatchNorm1d(lin3_size),
-             nn.ReLU(),
-             nn.Linear(lin3_size, output_size),
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Perform a single forward pass through the network.
-
-         :param x: The input tensor.
-         :return: A tensor of predictions.
-         """
-         batch_size, channels, width, height = x.size()
-
-         # (batch, 1, width, height) -> (batch, 1*width*height)
-         x = x.view(batch_size, -1)
-
-         return self.model(x)
-
-
- if __name__ == "__main__":
-     _ = SimpleDenseNet()
src/models/miniagent_module.py ADDED
@@ -0,0 +1,96 @@
+ from typing import Any, Dict, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from lightning import LightningModule
+ from torchmetrics import MaxMetric, MeanMetric
+ from torchmetrics.classification.accuracy import Accuracy
+
+ from transformers import BertModel
+
+
+ class MiniAgentModule(LightningModule):
+     def __init__(
+         self,
+         bert_model: str,
+         inst_proj_model: nn.Module,
+         tool_proj_model: nn.Module,
+         pred_model: nn.Module,
+         lr: float,
+     ) -> None:
+         super().__init__()
+
+         self.save_hyperparameters(
+             logger=False, ignore=["inst_proj_model", "tool_proj_model", "pred_model"]
+         )
+
+         self.bert_model = BertModel.from_pretrained(bert_model)
+         self.bert_model.eval()
+         self.bert_model.requires_grad_(False)
+
+         self.inst_proj_model = inst_proj_model
+         self.tool_proj_model = tool_proj_model
+         self.pred_model = pred_model
+
+         self.lr = lr
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         pass
+
+     def on_train_start(self) -> None:
+         pass
+
+     def training_step(
+         self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+     ) -> torch.Tensor:
+         B = batch["inst_ids"].shape[0]
+
+         inst_ids = batch["inst_ids"]
+         inst_mask = batch["inst_mask"]
+         tool_ids = batch["tool_desc_ids"]
+         tool_mask = batch["tool_desc_mask"]
+
+         inst_z = self.bert_model(inst_ids, inst_mask, return_dict=False)[1]
+         tool_z = self.bert_model(tool_ids, tool_mask, return_dict=False)[1]
+
+         inst_emb = self.inst_proj_model(inst_z)
+         tool_emb = self.tool_proj_model(tool_z)
+
+         inst_emb_r = inst_emb.unsqueeze(0).repeat(B, 1, 1).view(B * B, -1)
+         tool_emb_r = tool_emb.unsqueeze(1).repeat(1, B, 1).view(B * B, -1)
+
+         pred = self.pred_model(inst_emb_r, tool_emb_r)  # [BxB, 1]
+         pred = pred.view(B, B)  # [B, B]
+
+         # matching (instruction, tool) pairs lie on the diagonal
+         target = torch.eye(B, device=pred.device).float()
+
+         loss = F.binary_cross_entropy_with_logits(pred, target)
+
+         self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
+
+         return loss
+
+     def on_train_epoch_end(self) -> None:
+         pass
+
+     def validation_step(
+         self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+     ) -> None:
+         pass
+
+     def on_validation_epoch_end(self) -> None:
+         pass
+
+     def test_step(
+         self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+     ) -> None:
+         pass
+
+     def on_test_epoch_end(self) -> None:
+         pass
+
+     def configure_optimizers(self):
+         opt = torch.optim.AdamW(self.parameters(), lr=self.lr)
+         return opt
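A minimal sketch of the objective implemented by `training_step` above, using toy tensors in place of the BERT encodings: each batch of B (instruction, tool) pairs is expanded to all B x B combinations, and only the diagonal combinations (instruction i with its own tool i) are labeled positive, so the off-diagonal pairs serve as in-batch negatives.

```python
import torch
import torch.nn.functional as F

B, D = 4, 768  # toy batch size and embedding width
inst_emb = torch.randn(B, D)  # stand-in for projected instruction embeddings
tool_emb = torch.randn(B, D)  # stand-in for projected tool embeddings

# pair every tool (rows) with every instruction (columns)
inst_r = inst_emb.unsqueeze(0).repeat(B, 1, 1).view(B * B, -1)
tool_r = tool_emb.unsqueeze(1).repeat(1, B, 1).view(B * B, -1)

# stand-in for pred_model(inst_r, tool_r): one matching logit per pair
pred = torch.randn(B * B, 1).view(B, B)

# entry (i, j) is a true match only when i == j, hence the identity target
target = torch.eye(B)
loss = F.binary_cross_entropy_with_logits(pred, target)
print(loss)  # each row contributes 1 positive and B - 1 in-batch negatives
```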
src/models/mlp_module.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class MLPProjection(nn.Module):
+     def __init__(self, input_dim, hidden_dim, output_dim):
+         super().__init__()
+         self.linear1 = nn.Linear(input_dim, hidden_dim)
+         self.linear2 = nn.Linear(hidden_dim, output_dim)
+
+     def forward(self, x_output):
+         # x_output is the pooled [CLS] representation of each sequence
+         x = x_output
+
+         x = self.linear1(x)
+         x = F.silu(x)
+         x = self.linear2(x)
+
+         return x
+
+
+ class MLPPrediction(nn.Module):
+     def __init__(self, input_dim, use_abs_diff=False, use_mult=False):
+         super().__init__()
+
+         self.use_abs_diff = use_abs_diff
+         self.use_mult = use_mult
+
+         real_input_dim = input_dim * (2 + int(use_abs_diff) + int(use_mult))
+
+         self.mlp = nn.Sequential(
+             nn.Linear(real_input_dim, 1024),
+             nn.SiLU(),
+             nn.Linear(1024, 512),
+             nn.SiLU(),
+             nn.Linear(512, 256),
+             nn.SiLU(),
+             nn.Linear(256, 1),
+         )
+
+     def forward(self, x1, x2):
+         x = torch.cat([x1, x2], dim=1)
+
+         if self.use_abs_diff:
+             x_diff = torch.abs(x1 - x2)
+             x = torch.cat([x, x_diff], dim=1)
+
+         if self.use_mult:
+             x_mult = x1 * x2
+             x = torch.cat([x, x_mult], dim=1)
+
+         x = self.mlp(x)
+
+         return x
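A quick shape check of `MLPPrediction` (illustrative; the import path assumes the repo layout shown in this commit): concatenating `[x1, x2]` yields `2 * input_dim` features, and each enabled extra term (`|x1 - x2|`, `x1 * x2`) adds another `input_dim`, which is exactly what `real_input_dim` accounts for.

```python
import torch
from src.models.mlp_module import MLPPrediction  # path as added in this commit

x1 = torch.randn(8, 768)  # e.g. projected instruction embeddings
x2 = torch.randn(8, 768)  # e.g. projected tool embeddings

# with both flags on: real_input_dim = 768 * (2 + 1 + 1) = 3072
head = MLPPrediction(input_dim=768, use_abs_diff=True, use_mult=True)
scores = head(x1, x2)
print(scores.shape)  # torch.Size([8, 1]) -- one matching logit per pair
```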
src/models/mnist_module.py DELETED
@@ -1,217 +0,0 @@
- from typing import Any, Dict, Tuple
-
- import torch
- from lightning import LightningModule
- from torchmetrics import MaxMetric, MeanMetric
- from torchmetrics.classification.accuracy import Accuracy
-
-
- class MNISTLitModule(LightningModule):
-     """Example of a `LightningModule` for MNIST classification.
-
-     A `LightningModule` implements 8 key methods:
-
-     ```python
-     def __init__(self):
-     # Define initialization code here.
-
-     def setup(self, stage):
-     # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
-     # This hook is called on every process when using DDP.
-
-     def training_step(self, batch, batch_idx):
-     # The complete training step.
-
-     def validation_step(self, batch, batch_idx):
-     # The complete validation step.
-
-     def test_step(self, batch, batch_idx):
-     # The complete test step.
-
-     def predict_step(self, batch, batch_idx):
-     # The complete predict step.
-
-     def configure_optimizers(self):
-     # Define and configure optimizers and LR schedulers.
-     ```
-
-     Docs:
-         https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
-     """
-
-     def __init__(
-         self,
-         net: torch.nn.Module,
-         optimizer: torch.optim.Optimizer,
-         scheduler: torch.optim.lr_scheduler,
-         compile: bool,
-     ) -> None:
-         """Initialize a `MNISTLitModule`.
-
-         :param net: The model to train.
-         :param optimizer: The optimizer to use for training.
-         :param scheduler: The learning rate scheduler to use for training.
-         """
-         super().__init__()
-
-         # this line allows to access init params with 'self.hparams' attribute
-         # also ensures init params will be stored in ckpt
-         self.save_hyperparameters(logger=False)
-
-         self.net = net
-
-         # loss function
-         self.criterion = torch.nn.CrossEntropyLoss()
-
-         # metric objects for calculating and averaging accuracy across batches
-         self.train_acc = Accuracy(task="multiclass", num_classes=10)
-         self.val_acc = Accuracy(task="multiclass", num_classes=10)
-         self.test_acc = Accuracy(task="multiclass", num_classes=10)
-
-         # for averaging loss across batches
-         self.train_loss = MeanMetric()
-         self.val_loss = MeanMetric()
-         self.test_loss = MeanMetric()
-
-         # for tracking best so far validation accuracy
-         self.val_acc_best = MaxMetric()
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Perform a forward pass through the model `self.net`.
-
-         :param x: A tensor of images.
-         :return: A tensor of logits.
-         """
-         return self.net(x)
-
-     def on_train_start(self) -> None:
-         """Lightning hook that is called when training begins."""
-         # by default lightning executes validation step sanity checks before training starts,
-         # so it's worth to make sure validation metrics don't store results from these checks
-         self.val_loss.reset()
-         self.val_acc.reset()
-         self.val_acc_best.reset()
-
-     def model_step(
-         self, batch: Tuple[torch.Tensor, torch.Tensor]
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         """Perform a single model step on a batch of data.
-
-         :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
-
-         :return: A tuple containing (in order):
-             - A tensor of losses.
-             - A tensor of predictions.
-             - A tensor of target labels.
-         """
-         x, y = batch
-         logits = self.forward(x)
-         loss = self.criterion(logits, y)
-         preds = torch.argmax(logits, dim=1)
-         return loss, preds, y
-
-     def training_step(
-         self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
-     ) -> torch.Tensor:
-         """Perform a single training step on a batch of data from the training set.
-
-         :param batch: A batch of data (a tuple) containing the input tensor of images and target
-             labels.
-         :param batch_idx: The index of the current batch.
-         :return: A tensor of losses between model predictions and targets.
-         """
-         loss, preds, targets = self.model_step(batch)
-
-         # update and log metrics
-         self.train_loss(loss)
-         self.train_acc(preds, targets)
-         self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
-         self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
-
-         # return loss or backpropagation will fail
-         return loss
-
-     def on_train_epoch_end(self) -> None:
-         "Lightning hook that is called when a training epoch ends."
-         pass
-
-     def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
-         """Perform a single validation step on a batch of data from the validation set.
-
-         :param batch: A batch of data (a tuple) containing the input tensor of images and target
-             labels.
-         :param batch_idx: The index of the current batch.
-         """
-         loss, preds, targets = self.model_step(batch)
-
-         # update and log metrics
-         self.val_loss(loss)
-         self.val_acc(preds, targets)
-         self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
-         self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
-
-     def on_validation_epoch_end(self) -> None:
-         "Lightning hook that is called when a validation epoch ends."
-         acc = self.val_acc.compute()  # get current val acc
-         self.val_acc_best(acc)  # update best so far val acc
-         # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
-         # otherwise metric would be reset by lightning after each epoch
-         self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)
-
-     def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
-         """Perform a single test step on a batch of data from the test set.
-
-         :param batch: A batch of data (a tuple) containing the input tensor of images and target
-             labels.
-         :param batch_idx: The index of the current batch.
-         """
-         loss, preds, targets = self.model_step(batch)
-
-         # update and log metrics
-         self.test_loss(loss)
-         self.test_acc(preds, targets)
-         self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
-         self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
-
-     def on_test_epoch_end(self) -> None:
-         """Lightning hook that is called when a test epoch ends."""
-         pass
-
-     def setup(self, stage: str) -> None:
-         """Lightning hook that is called at the beginning of fit (train + validate), validate,
-         test, or predict.
-
-         This is a good hook when you need to build models dynamically or adjust something about
-         them. This hook is called on every process when using DDP.
-
-         :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
-         """
-         if self.hparams.compile and stage == "fit":
-             self.net = torch.compile(self.net)
-
-     def configure_optimizers(self) -> Dict[str, Any]:
-         """Choose what optimizers and learning-rate schedulers to use in your optimization.
-         Normally you'd need one. But in the case of GANs or similar you might have multiple.
-
-         Examples:
-             https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
-
-         :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
-         """
-         optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
-         if self.hparams.scheduler is not None:
-             scheduler = self.hparams.scheduler(optimizer=optimizer)
-             return {
-                 "optimizer": optimizer,
-                 "lr_scheduler": {
-                     "scheduler": scheduler,
-                     "monitor": "val/loss",
-                     "interval": "epoch",
-                     "frequency": 1,
-                 },
-             }
-         return {"optimizer": optimizer}
-
-
- if __name__ == "__main__":
-     _ = MNISTLitModule(None, None, None, None)
src/train.py CHANGED
@@ -38,6 +38,8 @@ from src.utils import (

  log = RankedLogger(__name__, rank_zero_only=True)

+ torch.set_float32_matmul_precision("medium")
+

  @task_wrapper
  def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

@@ -67,7 +69,9 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
      logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

      log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
-     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+     trainer: Trainer = hydra.utils.instantiate(
+         cfg.trainer, callbacks=callbacks, logger=logger
+     )

      object_dict = {
          "cfg": cfg,
test_bert.ipynb ADDED
@@ -0,0 +1,142 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "/home/qninh/miniconda3/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+       "  from .autonotebook import tqdm as notebook_tqdm\n"
+      ]
+     }
+    ],
+    "source": [
+     "from src.data.mixed_datamodule import MixedDataModule"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "datamodule = MixedDataModule(dataset_path=\"./data/mixed\", batch_size=32, num_workers=4, bert_model=\"bert-base-uncased\", tool_capacity=16)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "datamodule.setup(stage=\"fit\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "train_dataloader = datamodule.train_dataloader()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "{'instruction': torch.Size([32, 128]),\n",
+        " 'instruction_mask': torch.Size([32, 128]),\n",
+        " 'tool_desc_emb': torch.Size([32, 128]),\n",
+        " 'tool_desc_mask': torch.Size([32, 128])}"
+       ]
+      },
+      "execution_count": 5,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "# first sample\n",
+     "batch = next(iter(train_dataloader))\n",
+     "{\n",
+     "    key: value.shape\n",
+     "    for key, value in batch.items()\n",
+     "}"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 6,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "val_dataloader = datamodule.val_dataloader()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 7,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "{'instruction': torch.Size([32, 128]),\n",
+        " 'instruction_mask': torch.Size([32, 128]),\n",
+        " 'tool_desc_emb': torch.Size([32, 16, 128]),\n",
+        " 'tool_desc_mask': torch.Size([32, 16, 128])}"
+       ]
+      },
+      "execution_count": 7,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "# first sample\n",
+     "batch = next(iter(val_dataloader))\n",
+     "{\n",
+     "    key: value.shape\n",
+     "    for key, value in batch.items()\n",
+     "}"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "swim",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.14"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }