Commit 18943e5
Zai committed
Parent(s): a88d601

robot.py and dataset starting
headshot/headshot.py
CHANGED
@@ -1,9 +1,27 @@
 import torch
 from torch import nn
+from data_prep import FaceDataset
 
 class Headshot(nn.Module):
     def __init__(self):
         super().__init__()
+        self.dataset = FaceDataset()
+        self.num_epoch = 20
 
     def forward(self,x):
-        pass
+        pass
+
+    def train(self):
+        for epoch in range(self.num_epoch):
+            for i,(image,label) in enumerate(self.dataset):
+                pass
+
+    def predict(self,image):
+        points = self.forward(image)
+
+
+    def load_pretrain(self,name=""):
+        pretrained = torch.load(pretrained)
+        self.load_state_dict(pretrained)
+
+
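Editor's note: as committed, `load_pretrain` passes the not-yet-defined name `pretrained` to `torch.load` (a NameError when called) instead of `name`, `predict` computes `points` but never returns it, and `train` both shadows `nn.Module.train` and iterates the Dataset unbatched. A minimal corrected sketch, assuming `data_prep.FaceDataset` is a standard `torch.utils.data.Dataset`:

    import torch
    from torch import nn
    from torch.utils.data import DataLoader

    from data_prep import FaceDataset  # the repo's dataset module, as committed

    class Headshot(nn.Module):
        def __init__(self):
            super().__init__()
            self.dataset = FaceDataset()
            self.num_epoch = 20

        def forward(self, x):
            pass  # network layers still to be defined

        def train_model(self):
            # renamed: nn.Module already defines train() for train/eval mode switching
            loader = DataLoader(self.dataset, batch_size=32, shuffle=True)
            for epoch in range(self.num_epoch):
                for image, label in loader:
                    pass  # forward/backward pass goes here

        def predict(self, image):
            points = self.forward(image)
            return points  # return the predicted points

        def load_pretrain(self, name=""):
            pretrained = torch.load(name)  # load from the given path
            self.load_state_dict(pretrained)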
headshot/robot.py
CHANGED
@@ -1 +1,16 @@
-# to connect with some physical machine
+# to connect with some physical machine
+from pyrobot import Robot
+
+# Initialize the robot
+robot = Robot('locobot')
+
+def move_forward(distance):
+    # Move the robot forward
+    robot.base.go_to_relative([distance, 0, 0])
+
+if __name__ == "__main__":
+    # Define the displacement (in meters) for the robot to move forward
+    forward_distance = 0.5  # Move 0.5 meters forward
+
+    # Move the robot forward
+    move_forward(forward_distance)
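Editor's note: PyRobot's `go_to_relative` takes an `[x, y, theta]` displacement in the base frame, so the same call also handles rotation. A small sketch of companion helpers that could sit alongside `move_forward` in robot.py, reusing its module-level `robot` (the angle convention here is an assumption, not verified against the LoCoBot docs):

    import math

    def turn_in_place(angle_rad=math.pi / 2):
        # Rotate in place: zero linear displacement, theta in radians
        robot.base.go_to_relative([0, 0, angle_rad])

    def move_and_turn(distance, angle_rad):
        # A single relative motion combining translation and rotation
        robot.base.go_to_relative([distance, 0, angle_rad])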
notebooks/.ipynb_checkpoints/detection_pytorch-checkpoint.ipynb
DELETED
@@ -1,190 +0,0 @@
Deleted in full. The file was raw .ipynb JSON (Python 3 ipykernel, Python 3.10.13, nbformat 4.5); its code cells, transcribed in order:

# Cell 1 — imports. Stored output: a torchvision UserWarning, "Failed to load
# image Python extension: '[WinError 127] The specified procedure could not be
# found'", from a Windows conda env (ai_env).
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
from tqdm import tqdm

# Cell 2
# declaration

# Cell 3
# Load image datas

device = 'cpu'

# Cell 4 — transcribed as written; note `tranforms` is misspelled, `x_data`/`y_labels`
# are undefined (the parameters are `data` and `labels`), and __getitem__ reads
# `self.transform`, which __init__ never sets.
# dataset
class FaceDataset(Dataset):
    def __init__(self,data,labels,transforms=None):
        self.tranforms = tranforms
        self.data = x_data
        self.labels = y_labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Load and preprocess the image at the given index
        image = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)
        return image,label

# Cell 5
class Detector(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(72,64,4)

    def forward(self):
        pass

# Cell 6
# optimization
lr = 1e-3
# model = Detector().to(device)
# optimizer = torch.optim.Adam(model)
loss_fn = nn.CrossEntropyLoss()

num_epochs = 50

# Cell 7 — stored output: "xi" on stdout, then
# TypeError: CrossEntropyLoss.forward() missing 2 required positional arguments:
# 'input' and 'target'. (The saved traceback reflects an earlier edit of this cell
# that printed 'xi' in a loop and then called loss_fn() with no arguments.)
# train model

dummy_dataset= ['hello']

for i in range(num_epochs):
    for img,label in enumerate(dummy_dataset):
        optimizer.zero_grad()
        outputs = model(img,label)
        loss = loss_fn(outputs,labels)
        loss.backward()
        optimizer.step()
        print(f"epoch {i} done")

# Cell 8
# eval model

# Cell 9 — empty
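Editor's note: taken together, the deleted cells sketch a dataset class and training loop that never ran end to end. A minimal runnable sketch of what they were reaching for, with the bugs fixed and random tensors standing in for real face data (the stand-in model and tensor shapes are assumptions, not the repo's):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset

    class FaceDataset(Dataset):
        def __init__(self, data, labels, transform=None):
            self.transform = transform  # one consistent attribute name
            self.data = data            # bind the actual parameters
            self.labels = labels

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            image = self.data[idx]
            label = self.labels[idx]
            if self.transform:
                image = self.transform(image)
            return image, label

    # Stand-in data: 8 RGB 64x64 images with binary labels (assumed shapes)
    data = torch.randn(8, 3, 64, 64)
    labels = torch.randint(0, 2, (8,))
    loader = DataLoader(FaceDataset(data, labels), batch_size=4)

    # Stand-in model: the deleted Detector had an empty forward, so a tiny
    # classifier is substituted purely to make the loop executable
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 2))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # Adam takes .parameters(), not the module
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(2):
        for image, label in loader:          # iterate the DataLoader directly
            optimizer.zero_grad()
            outputs = model(image)           # the model sees only the inputs
            loss = loss_fn(outputs, label)   # loss gets (input, target), fixing the stored TypeError
            loss.backward()
            optimizer.step()
        print(f"epoch {epoch} done")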
notebooks/detection_pytorch.ipynb
DELETED
@@ -1,190 +0,0 @@
Deleted in full; its 190 lines were identical, line for line, to the .ipynb_checkpoints copy transcribed above.