Thouph commited on
Commit
6f0e0c9
1 Parent(s): 9ed37aa

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference.py +42 -0
  2. model.pth +3 -0
inference.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision.transforms import transforms
6
+
7
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
+ model = torch.load("path/to/your/model.pth")
9
+ model.to(device)
10
+ model.eval()
11
+
12
+ transform = transforms.Compose([
13
+ transforms.ToTensor(),
14
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
15
+ ])
16
+
17
+ with open("tags_8034.json", "r") as f:
18
+ tags = json.load(f)
19
+ tags = sorted(tags)
20
+ tags.append("placeholder0")
21
+ tags.append("explicit")
22
+ tags.append("questionable")
23
+ tags.append("safe")
24
+
25
+ image_path = "path/to/your/image.jpg"
26
+ start = time.time()
27
+ img = Image.open(image_path).convert('RGB')
28
+ img.thumbnail((448, 448), Image.LANCZOS)
29
+ tensor = transform(img).unsqueeze(0).to(device)
30
+ with torch.no_grad():
31
+ out = model(tensor)
32
+ probabilities = torch.nn.functional.sigmoid(out[0])
33
+
34
+ indices = torch.where(probabilities > 0.3)[0]
35
+ values = probabilities[indices]
36
+
37
+ for i in range(indices.size(0)):
38
+ print(tags[indices[i]], values[i].item())
39
+
40
+ end = time.time()
41
+ print(f'Executed in {end - start} seconds')
42
+ print("\n\n", end="")
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0304c62d8f3e79834122711907f0bda37a415ffee9692665791c4b1ccd36d8d7
3
+ size 254300130