Kaori1707 committed on
Commit
541eb9d
•
1 Parent(s): 69ef5c2

Update classification.py

Files changed (1)
  1. classification.py +85 -85
classification.py CHANGED
@@ -1,86 +1,86 @@
- import math
- import torch
- from timm import create_model
- from torchvision import transforms
- from PIL import Image
-
- # Default crop ratio
- DEFAULT_CROP_PCT = 0.875
-
- # ResNeSt-101e model setup and loading
- weights_path = "./weights/resnest101e.in1k_weight_Pear_classification.pt"  # local weights file path
- model_name = "resnest101e"
- model = create_model(model_name, pretrained=False, num_classes=9)  # skip loading pretrained weights
- # model.classifier = torch.nn.Linear(model.classifier.in_features, 2)  # switch to binary classification
- model.load_state_dict(torch.load(weights_path))  # load local weights
- model.eval()  # set evaluation mode
-
- # Class names, specified manually
- # class_labels = ["Abnormal Pear", "Normal Pear"]  # adjust to the dataset
- class_labels = ["정상", "흑성병", "과피얼룩", "복숭아 순나방", "복숭아 삼식나방", "배 깍지벌레", "잎말이 나방류", "기타", "과피흑변"]
-
- # Preprocessing function
- def transforms_imagenet_eval(
-     img_path: str,
-     img_size: int = 224,
-     crop_pct: float = DEFAULT_CROP_PCT,
-     mean: tuple = (0.485, 0.456, 0.406),
-     std: tuple = (0.229, 0.224, 0.225),
-     normalize: bool = True,
- ):
-     """
-     ImageNet-style image preprocessing function.
-
-     Args:
-         img_path (str): Image path.
-         img_size (int): Crop size.
-         crop_pct (float): Crop ratio.
-         mean (tuple): Normalization mean.
-         std (tuple): Normalization standard deviation.
-         normalize (bool): Whether to normalize.
-
-     Returns:
-         torch.Tensor: Preprocessed image tensor.
-     """
-     img = Image.open(img_path).convert("RGB")  # load image and convert to RGB
-     scale_size = math.floor(img_size / crop_pct)  # pre-crop resize size, e.g. floor(224 / 0.875) = 256
-
-     # Build the transform pipeline
-     tfl = [
-         transforms.Resize((scale_size, scale_size), interpolation=transforms.InterpolationMode.BILINEAR),
-         transforms.CenterCrop(img_size),
-         transforms.ToTensor(),
-     ]
-     if normalize:
-         tfl += [transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
-
-     transform = transforms.Compose(tfl)
-     return transform(img)
-
- # Inference function
- def predict(image_path: str):
-     """
-     Run model inference on the given image file path.
-
-     Args:
-         image_path (str): Input image path.
-
-     Returns:
-         dict: Mapping from class label to predicted probability.
-     """
-     # Preprocess the image
-     input_tensor = transforms_imagenet_eval(
-         img_path=image_path,
-         img_size=224,
-         normalize=True
-     ).unsqueeze(0)  # add batch dimension
-
-     # Run the model
-     with torch.no_grad():
-         prediction = model(input_tensor)
-
-     probs = torch.nn.functional.softmax(prediction[0], dim=-1)
-     confidences = {class_labels[i]: float(probs[i]) for i in range(9)}
-
-     # Return the prediction results
+ import math
+ import torch
+ from timm import create_model
+ from torchvision import transforms
+ from PIL import Image
+
+ # Default crop ratio
+ DEFAULT_CROP_PCT = 0.875
+
+ # ResNeSt-101e model setup and loading
+ weights_path = "./weights/resnest101e.in1k_weight_Pear_classification.pt"  # local weights file path
+ model_name = "resnest101e"
+ model = create_model(model_name, pretrained=False, num_classes=9)  # skip loading pretrained weights
+ # model.classifier = torch.nn.Linear(model.classifier.in_features, 2)  # switch to binary classification
+ model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))  # load local weights
+ model.eval()  # set evaluation mode
+
+ # Class names, specified manually
+ # class_labels = ["Abnormal Pear", "Normal Pear"]  # adjust to the dataset
+ class_labels = ["정상", "흑성병", "과피얼룩", "복숭아 순나방", "복숭아 삼식나방", "배 깍지벌레", "잎말이 나방류", "기타", "과피흑변"]
+
+ # Preprocessing function
+ def transforms_imagenet_eval(
+     img_path: str,
+     img_size: int = 224,
+     crop_pct: float = DEFAULT_CROP_PCT,
+     mean: tuple = (0.485, 0.456, 0.406),
+     std: tuple = (0.229, 0.224, 0.225),
+     normalize: bool = True,
+ ):
+     """
+     ImageNet-style image preprocessing function.
+
+     Args:
+         img_path (str): Image path.
+         img_size (int): Crop size.
+         crop_pct (float): Crop ratio.
+         mean (tuple): Normalization mean.
+         std (tuple): Normalization standard deviation.
+         normalize (bool): Whether to normalize.
+
+     Returns:
+         torch.Tensor: Preprocessed image tensor.
+     """
+     img = Image.open(img_path).convert("RGB")  # load image and convert to RGB
+     scale_size = math.floor(img_size / crop_pct)  # pre-crop resize size, e.g. floor(224 / 0.875) = 256
+
+     # Build the transform pipeline
+     tfl = [
+         transforms.Resize((scale_size, scale_size), interpolation=transforms.InterpolationMode.BILINEAR),
+         transforms.CenterCrop(img_size),
+         transforms.ToTensor(),
+     ]
+     if normalize:
+         tfl += [transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
+
+     transform = transforms.Compose(tfl)
+     return transform(img)
+
+ # Inference function
+ def predict(image_path: str):
+     """
+     Run model inference on the given image file path.
+
+     Args:
+         image_path (str): Input image path.
+
+     Returns:
+         dict: Mapping from class label to predicted probability.
+     """
+     # Preprocess the image
+     input_tensor = transforms_imagenet_eval(
+         img_path=image_path,
+         img_size=224,
+         normalize=True
+     ).unsqueeze(0)  # add batch dimension
+
+     # Run the model
+     with torch.no_grad():
+         prediction = model(input_tensor)
+
+     probs = torch.nn.functional.softmax(prediction[0], dim=-1)
+     confidences = {class_labels[i]: float(probs[i]) for i in range(9)}
+
+     # Return the prediction results
      return confidences
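
The only functional change in this commit is to the weight-loading line: torch.load(weights_path) becomes torch.load(weights_path, map_location=torch.device('cpu')). Without map_location, PyTorch tries to restore each tensor to the device it was saved from, so a checkpoint written on a GPU machine fails to load on a CPU-only host (such as a typical Spaces container). Below is a minimal sketch of the more general, device-agnostic variant of the same idea; it is not part of the commit.

import torch
from timm import create_model

# Pick whatever device is available; map_location remaps the saved
# tensors onto it, so the same code runs on CPU-only and GPU hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_model("resnest101e", pretrained=False, num_classes=9)
state_dict = torch.load(
    "./weights/resnest101e.in1k_weight_Pear_classification.pt",
    map_location=device,
)
model.load_state_dict(state_dict)
model = model.eval().to(device)  # inputs must be moved to the same device before calling model(...)

With the module as committed, inference returns a dictionary of per-class confidences; "pear.jpg" is a placeholder path:

confidences = predict("pear.jpg")  # maps each class label to a probability
top_label = max(confidences, key=confidences.get)  # top-1 class label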