Spaces:
Sleeping
Sleeping
Sergei
commited on
Commit
·
8f186c9
1
Parent(s):
a2644fb
Переобученная модель + файл саспознания переделан, для запуска из другого файла
Browse filesFormer-commit-id: dfb192eb464f22cd65b9f947a889aecd1d8d1243
Former-commit-id: 4e9ba30f7f67980b897d1269e0395aec18418448
- Test_photo/.DS_Store +0 -0
- Test_photo/1.jpg +0 -0
- Test_photo/2.jpg +0 -0
- Test_photo/3.jpg +0 -0
- Test_photo/4.jpg +0 -0
- Test_photo/5.jpg +0 -0
- Test_photo/6.jpg +0 -0
- cat.csv +1 -0
- check_photo.py +25 -0
- check_photo_model_init.py +17 -0
- check_photo_model_retrain.py +154 -0
- pickle_model.pkl.REMOVED.git-id +1 -1
- test_check_photo.py +21 -0
Test_photo/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Test_photo/1.jpg
ADDED
Test_photo/2.jpg
ADDED
Test_photo/3.jpg
ADDED
Test_photo/4.jpg
ADDED
Test_photo/5.jpg
ADDED
Test_photo/6.jpg
ADDED
cat.csv
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Мемориальная квартира Пушкина на Арбате,Новый Арбат,Памятник Александру Пушкину и Наталье Гончаровой,Памятники Булату Окуджаве в Москве,"Художественный (кинотеатр, Москва)",Центральный Дом актёра имени А. А. Яблочкиной
|
check_photo.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
from PIL import Image
|
4 |
+
import torchvision
|
5 |
+
|
6 |
+
# Запуск модели для распознания фото
|
7 |
+
def check_photo1(model, categorias, photo):
|
8 |
+
# Тот же формат фото, что и при обучении
|
9 |
+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
10 |
+
std=[0.229, 0.224, 0.225])
|
11 |
+
preprocess = transforms.Compose([
|
12 |
+
transforms.Resize([70, 70]),
|
13 |
+
transforms.RandomHorizontalFlip(),
|
14 |
+
transforms.RandomAutocontrast(),
|
15 |
+
transforms.RandomEqualize(),
|
16 |
+
transforms.ToTensor(),
|
17 |
+
normalize
|
18 |
+
])
|
19 |
+
batch = preprocess(photo).unsqueeze(0)
|
20 |
+
prediction = model(batch).squeeze(0).softmax(0)
|
21 |
+
class_id = prediction.argmax().item()
|
22 |
+
score = prediction[class_id].item()
|
23 |
+
return categorias[class_id], score
|
24 |
+
|
25 |
+
|
check_photo_model_init.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import csv
|
3 |
+
|
4 |
+
# Файл инициализации модели
|
5 |
+
def init_model():
|
6 |
+
# Загрузить модели из файла
|
7 |
+
pkl_filename = "pickle_model.pkl"
|
8 |
+
with open(pkl_filename, 'rb') as file:
|
9 |
+
model = pickle.load(file)
|
10 |
+
|
11 |
+
# Считывание категорий
|
12 |
+
file = open("cat.csv", "r")
|
13 |
+
cat1 = list(csv.reader(file, delimiter=","))
|
14 |
+
categorias = cat1[0]
|
15 |
+
file.close()
|
16 |
+
model.eval()
|
17 |
+
return model, categorias
|
check_photo_model_retrain.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import torchvision
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
import shutil
|
9 |
+
import time
|
10 |
+
import xml.etree.ElementTree as et
|
11 |
+
import pickle
|
12 |
+
import csv
|
13 |
+
|
14 |
+
from tqdm import tqdm
|
15 |
+
from PIL import Image
|
16 |
+
from torchvision import models
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from torchvision.datasets import ImageFolder
|
19 |
+
# Размер одного пакета
|
20 |
+
BATCH_SIZE = 32
|
21 |
+
|
22 |
+
use_gpu = torch.cuda.is_available()
|
23 |
+
device = 'cuda' if use_gpu else 'cpu'
|
24 |
+
print('Connected device:', device)
|
25 |
+
|
26 |
+
# Датасет для тренировки
|
27 |
+
train_dataset = ImageFolder(
|
28 |
+
root='Data/Train'
|
29 |
+
)
|
30 |
+
# Датасет для проверки
|
31 |
+
valid_dataset = ImageFolder(
|
32 |
+
root='Data/Valid'
|
33 |
+
)
|
34 |
+
|
35 |
+
# augmentations (ухудшение качество чтобы не было переобучения)
|
36 |
+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
37 |
+
std=[0.229, 0.224, 0.225])
|
38 |
+
train_dataset.transform = transforms.Compose([
|
39 |
+
transforms.Resize([70, 70]),
|
40 |
+
transforms.RandomHorizontalFlip(),
|
41 |
+
transforms.RandomAutocontrast(),
|
42 |
+
transforms.RandomEqualize(),
|
43 |
+
transforms.ToTensor(),
|
44 |
+
normalize
|
45 |
+
])
|
46 |
+
|
47 |
+
valid_dataset.transform = transforms.Compose([
|
48 |
+
transforms.Resize([70, 70]),
|
49 |
+
transforms.ToTensor(),
|
50 |
+
normalize
|
51 |
+
])
|
52 |
+
|
53 |
+
# Определение выборки для обучения
|
54 |
+
train_loader = DataLoader(
|
55 |
+
train_dataset, batch_size=BATCH_SIZE,
|
56 |
+
shuffle=True
|
57 |
+
)
|
58 |
+
# Определение выборки для проверки
|
59 |
+
valid_loader = DataLoader(
|
60 |
+
valid_dataset, batch_size=BATCH_SIZE,
|
61 |
+
shuffle=False
|
62 |
+
)
|
63 |
+
|
64 |
+
# Указание на используемую модель
|
65 |
+
def google(): # pretrained=True для tensorflow
|
66 |
+
model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
|
67 |
+
# Добавление линейного (выходного) слоя на основании которого идет дообучение
|
68 |
+
model.fc = torch.nn.Linear(1024, len(train_dataset.classes))
|
69 |
+
for param in model.parameters():
|
70 |
+
param.requires_grad = True
|
71 |
+
# Заморозка весов т.к. при переобучении модели они должны быть постоянны, а меняться будет только последний слой
|
72 |
+
model.inception3a.requires_grad = False
|
73 |
+
model.inception3b.requires_grad = False
|
74 |
+
model.inception4a.requires_grad = False
|
75 |
+
model.inception4b.requires_grad = False
|
76 |
+
model.inception4c.requires_grad = False
|
77 |
+
model.inception4d.requires_grad = False
|
78 |
+
model.inception4e.requires_grad = False
|
79 |
+
return model
|
80 |
+
|
81 |
+
# Функция обучения модели. Epoch - количество итераций обучения (прогонов по нейросети)
|
82 |
+
def train(model, optimizer, train_loader, val_loader, epoch=10):
|
83 |
+
loss_train, acc_train = [], []
|
84 |
+
loss_valid, acc_valid = [], []
|
85 |
+
# tqdm - прогресс бар
|
86 |
+
for epoch in tqdm(range(epoch)):
|
87 |
+
# Ошибки
|
88 |
+
losses, equals = [], []
|
89 |
+
torch.set_grad_enabled(True)
|
90 |
+
|
91 |
+
# Train. Обучение. В цикле проходится по картинкам и оптимизируются потери
|
92 |
+
model.train()
|
93 |
+
for i, (image, target) in enumerate(train_loader):
|
94 |
+
image = image.to(device)
|
95 |
+
target = target.to(device)
|
96 |
+
output = model(image)
|
97 |
+
loss = criterion(output,target)
|
98 |
+
|
99 |
+
losses.append(loss.item())
|
100 |
+
equals.extend(
|
101 |
+
[x.item() for x in torch.argmax(output, 1) == target])
|
102 |
+
|
103 |
+
optimizer.zero_grad()
|
104 |
+
loss.backward()
|
105 |
+
optimizer.step()
|
106 |
+
# Метрики отображающие резултитаты обучения модели
|
107 |
+
loss_train.append(np.mean(losses))
|
108 |
+
acc_train.append(np.mean(equals))
|
109 |
+
losses, equals = [], []
|
110 |
+
torch.set_grad_enabled(False)
|
111 |
+
|
112 |
+
# Validate. Оценка качества обучения
|
113 |
+
model.eval()
|
114 |
+
for i , (image, target) in enumerate(valid_loader):
|
115 |
+
image = image.to(device)
|
116 |
+
target = target.to(device)
|
117 |
+
|
118 |
+
output = model(image)
|
119 |
+
loss = criterion(output,target)
|
120 |
+
|
121 |
+
losses.append(loss.item())
|
122 |
+
equals.extend(
|
123 |
+
[y.item() for y in torch.argmax(output, 1) == target])
|
124 |
+
|
125 |
+
loss_valid.append(np.mean(losses))
|
126 |
+
acc_valid.append(np.mean(equals))
|
127 |
+
|
128 |
+
return loss_train, acc_train, loss_valid, acc_valid
|
129 |
+
|
130 |
+
criterion = torch.nn.CrossEntropyLoss()
|
131 |
+
criterion = criterion.to(device)
|
132 |
+
|
133 |
+
model = google()
|
134 |
+
print('Model: GoogLeNet\n')
|
135 |
+
|
136 |
+
# оптимайзер - отвечает за поиск и подбор оптимальных весов
|
137 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
138 |
+
model = model.to(device)
|
139 |
+
|
140 |
+
loss_train, acc_train, loss_valid, acc_valid = train(
|
141 |
+
model, optimizer, train_loader, valid_loader, 30)
|
142 |
+
print('acc_train:', acc_train, '\nacc_valid:', acc_valid)
|
143 |
+
|
144 |
+
# Сохранение модели в текущую рабочую директорию
|
145 |
+
pkl_filename = "pickle_model.pkl"
|
146 |
+
with open(pkl_filename, 'wb') as file:
|
147 |
+
pickle.dump(model, file)
|
148 |
+
|
149 |
+
# Категории. Получаются из имен папок
|
150 |
+
print(train_dataset.classes)
|
151 |
+
# Экспорт категорий в CSV
|
152 |
+
with open('cat.csv', 'w', newline='') as file:
|
153 |
+
writer = csv.writer(file)
|
154 |
+
writer.writerow(train_dataset.classes)
|
pickle_model.pkl.REMOVED.git-id
CHANGED
@@ -1 +1 @@
|
|
1 |
-
|
|
|
1 |
+
ad5ecffeba11c98262cbe84ad38397d2accd8892
|
test_check_photo.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from check_photo import *
|
2 |
+
from check_photo_model_init import *
|
3 |
+
|
4 |
+
model, cat = init_model()
|
5 |
+
|
6 |
+
Puskin_pamiatnik = Image.open("Data/Test_photo/1.jpg")
|
7 |
+
Nov_arbat1 = Image.open("Data/Test_photo/2.jpg")
|
8 |
+
Pushkin_dom1 = Image.open("Data/Test_photo/3.jpg")
|
9 |
+
CDA1 = Image.open("Data/Test_photo/4.jpg")
|
10 |
+
Okudjava1 = Image.open("Data/Test_photo/5.jpg")
|
11 |
+
Kinoteatr = Image.open("Data/Test_photo/6.jpg")
|
12 |
+
test_photos_dict = {'Puskin_pamiatnik':Puskin_pamiatnik,
|
13 |
+
'Nov_arbat1':Nov_arbat1,
|
14 |
+
'Pushkin_dom1': Pushkin_dom1,
|
15 |
+
'CDA1': CDA1,
|
16 |
+
'Okudjava1': Okudjava1,
|
17 |
+
'Kinoteatr': Kinoteatr,
|
18 |
+
}
|
19 |
+
for name in test_photos_dict:
|
20 |
+
res_cat, res_score = check_photo1(model, cat, test_photos_dict[name])
|
21 |
+
print(f"{res_cat}: {100 * res_score:.1f}%", "right answer", name)
|