Sergei commited on
Commit
8f186c9
·
1 Parent(s): a2644fb

Переобученная модель + файл саспознания переделан, для запуска из другого файла

Browse files

Former-commit-id: dfb192eb464f22cd65b9f947a889aecd1d8d1243
Former-commit-id: 4e9ba30f7f67980b897d1269e0395aec18418448

Test_photo/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Test_photo/1.jpg ADDED
Test_photo/2.jpg ADDED
Test_photo/3.jpg ADDED
Test_photo/4.jpg ADDED
Test_photo/5.jpg ADDED
Test_photo/6.jpg ADDED
cat.csv ADDED
@@ -0,0 +1 @@
 
 
1
+ Мемориальная квартира Пушкина на Арбате,Новый Арбат,Памятник Александру Пушкину и Наталье Гончаровой,Памятники Булату Окуджаве в Москве,"Художественный (кинотеатр, Москва)",Центральный Дом актёра имени А. А. Яблочкиной
check_photo.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import torchvision
5
+
6
+ # Запуск модели для распознания фото
7
+ def check_photo1(model, categorias, photo):
8
+ # Тот же формат фото, что и при обучении
9
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
10
+ std=[0.229, 0.224, 0.225])
11
+ preprocess = transforms.Compose([
12
+ transforms.Resize([70, 70]),
13
+ transforms.RandomHorizontalFlip(),
14
+ transforms.RandomAutocontrast(),
15
+ transforms.RandomEqualize(),
16
+ transforms.ToTensor(),
17
+ normalize
18
+ ])
19
+ batch = preprocess(photo).unsqueeze(0)
20
+ prediction = model(batch).squeeze(0).softmax(0)
21
+ class_id = prediction.argmax().item()
22
+ score = prediction[class_id].item()
23
+ return categorias[class_id], score
24
+
25
+
check_photo_model_init.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import csv
3
+
4
+ # Файл инициализации модели
5
+ def init_model():
6
+ # Загрузить модели из файла
7
+ pkl_filename = "pickle_model.pkl"
8
+ with open(pkl_filename, 'rb') as file:
9
+ model = pickle.load(file)
10
+
11
+ # Считывание категорий
12
+ file = open("cat.csv", "r")
13
+ cat1 = list(csv.reader(file, delimiter=","))
14
+ categorias = cat1[0]
15
+ file.close()
16
+ model.eval()
17
+ return model, categorias
check_photo_model_retrain.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import torchvision
6
+ import matplotlib.pyplot as plt
7
+ import torchvision.transforms as transforms
8
+ import shutil
9
+ import time
10
+ import xml.etree.ElementTree as et
11
+ import pickle
12
+ import csv
13
+
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+ from torchvision import models
17
+ from torch.utils.data import DataLoader
18
+ from torchvision.datasets import ImageFolder
19
+ # Размер одного пакета
20
+ BATCH_SIZE = 32
21
+
22
+ use_gpu = torch.cuda.is_available()
23
+ device = 'cuda' if use_gpu else 'cpu'
24
+ print('Connected device:', device)
25
+
26
+ # Датасет для тренировки
27
+ train_dataset = ImageFolder(
28
+ root='Data/Train'
29
+ )
30
+ # Датасет для проверки
31
+ valid_dataset = ImageFolder(
32
+ root='Data/Valid'
33
+ )
34
+
35
+ # augmentations (ухудшение качество чтобы не было переобучения)
36
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
37
+ std=[0.229, 0.224, 0.225])
38
+ train_dataset.transform = transforms.Compose([
39
+ transforms.Resize([70, 70]),
40
+ transforms.RandomHorizontalFlip(),
41
+ transforms.RandomAutocontrast(),
42
+ transforms.RandomEqualize(),
43
+ transforms.ToTensor(),
44
+ normalize
45
+ ])
46
+
47
+ valid_dataset.transform = transforms.Compose([
48
+ transforms.Resize([70, 70]),
49
+ transforms.ToTensor(),
50
+ normalize
51
+ ])
52
+
53
+ # Определение выборки для обучения
54
+ train_loader = DataLoader(
55
+ train_dataset, batch_size=BATCH_SIZE,
56
+ shuffle=True
57
+ )
58
+ # Определение выборки для проверки
59
+ valid_loader = DataLoader(
60
+ valid_dataset, batch_size=BATCH_SIZE,
61
+ shuffle=False
62
+ )
63
+
64
+ # Указание на используемую модель
65
+ def google(): # pretrained=True для tensorflow
66
+ model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
67
+ # Добавление линейного (выходного) слоя на основании которого идет дообучение
68
+ model.fc = torch.nn.Linear(1024, len(train_dataset.classes))
69
+ for param in model.parameters():
70
+ param.requires_grad = True
71
+ # Заморозка весов т.к. при переобучении модели они должны быть постоянны, а меняться будет только последний слой
72
+ model.inception3a.requires_grad = False
73
+ model.inception3b.requires_grad = False
74
+ model.inception4a.requires_grad = False
75
+ model.inception4b.requires_grad = False
76
+ model.inception4c.requires_grad = False
77
+ model.inception4d.requires_grad = False
78
+ model.inception4e.requires_grad = False
79
+ return model
80
+
81
+ # Функция обучения модели. Epoch - количество итераций обучения (прогонов по нейросети)
82
+ def train(model, optimizer, train_loader, val_loader, epoch=10):
83
+ loss_train, acc_train = [], []
84
+ loss_valid, acc_valid = [], []
85
+ # tqdm - прогресс бар
86
+ for epoch in tqdm(range(epoch)):
87
+ # Ошибки
88
+ losses, equals = [], []
89
+ torch.set_grad_enabled(True)
90
+
91
+ # Train. Обучение. В цикле проходится по картинкам и оптимизируются потери
92
+ model.train()
93
+ for i, (image, target) in enumerate(train_loader):
94
+ image = image.to(device)
95
+ target = target.to(device)
96
+ output = model(image)
97
+ loss = criterion(output,target)
98
+
99
+ losses.append(loss.item())
100
+ equals.extend(
101
+ [x.item() for x in torch.argmax(output, 1) == target])
102
+
103
+ optimizer.zero_grad()
104
+ loss.backward()
105
+ optimizer.step()
106
+ # Метрики отображающие резултитаты обучения модели
107
+ loss_train.append(np.mean(losses))
108
+ acc_train.append(np.mean(equals))
109
+ losses, equals = [], []
110
+ torch.set_grad_enabled(False)
111
+
112
+ # Validate. Оценка качества обучения
113
+ model.eval()
114
+ for i , (image, target) in enumerate(valid_loader):
115
+ image = image.to(device)
116
+ target = target.to(device)
117
+
118
+ output = model(image)
119
+ loss = criterion(output,target)
120
+
121
+ losses.append(loss.item())
122
+ equals.extend(
123
+ [y.item() for y in torch.argmax(output, 1) == target])
124
+
125
+ loss_valid.append(np.mean(losses))
126
+ acc_valid.append(np.mean(equals))
127
+
128
+ return loss_train, acc_train, loss_valid, acc_valid
129
+
130
+ criterion = torch.nn.CrossEntropyLoss()
131
+ criterion = criterion.to(device)
132
+
133
+ model = google()
134
+ print('Model: GoogLeNet\n')
135
+
136
+ # оптимайзер - отвечает за поиск и подбор оптимальных весов
137
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
138
+ model = model.to(device)
139
+
140
+ loss_train, acc_train, loss_valid, acc_valid = train(
141
+ model, optimizer, train_loader, valid_loader, 30)
142
+ print('acc_train:', acc_train, '\nacc_valid:', acc_valid)
143
+
144
+ # Сохранение модели в текущую рабочую директорию
145
+ pkl_filename = "pickle_model.pkl"
146
+ with open(pkl_filename, 'wb') as file:
147
+ pickle.dump(model, file)
148
+
149
+ # Категории. Получаются из имен папок
150
+ print(train_dataset.classes)
151
+ # Экспорт категорий в CSV
152
+ with open('cat.csv', 'w', newline='') as file:
153
+ writer = csv.writer(file)
154
+ writer.writerow(train_dataset.classes)
pickle_model.pkl.REMOVED.git-id CHANGED
@@ -1 +1 @@
1
- 7f6f655ec6c2e6c5e3909cb0b10f718b1e648be5
 
1
+ ad5ecffeba11c98262cbe84ad38397d2accd8892
test_check_photo.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from check_photo import *
2
+ from check_photo_model_init import *
3
+
4
+ model, cat = init_model()
5
+
6
+ Puskin_pamiatnik = Image.open("Data/Test_photo/1.jpg")
7
+ Nov_arbat1 = Image.open("Data/Test_photo/2.jpg")
8
+ Pushkin_dom1 = Image.open("Data/Test_photo/3.jpg")
9
+ CDA1 = Image.open("Data/Test_photo/4.jpg")
10
+ Okudjava1 = Image.open("Data/Test_photo/5.jpg")
11
+ Kinoteatr = Image.open("Data/Test_photo/6.jpg")
12
+ test_photos_dict = {'Puskin_pamiatnik':Puskin_pamiatnik,
13
+ 'Nov_arbat1':Nov_arbat1,
14
+ 'Pushkin_dom1': Pushkin_dom1,
15
+ 'CDA1': CDA1,
16
+ 'Okudjava1': Okudjava1,
17
+ 'Kinoteatr': Kinoteatr,
18
+ }
19
+ for name in test_photos_dict:
20
+ res_cat, res_score = check_photo1(model, cat, test_photos_dict[name])
21
+ print(f"{res_cat}: {100 * res_score:.1f}%", "right answer", name)