Muinez commited on
Commit
1d56378
·
verified ·
1 Parent(s): 69791ae

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +56 -30
  2. dbimutils.py +2 -2
app.py CHANGED
@@ -1,31 +1,57 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoImageProcessor, ConvNextV2ForImageClassification
4
- from transformers import AutoModelForImageClassification
5
- from torch import nn
6
- import dbimutils as utils
7
-
8
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
9
-
10
- image_processor = AutoImageProcessor.from_pretrained("Muinez/artwork-scorer")
11
- model = AutoModelForImageClassification.from_pretrained("Muinez/artwork-scorer", problem_type="multi_label_classification").to(DEVICE)
12
-
13
- def predict(img):
14
- file = utils.preprocess_image(img)
15
- encoded = image_processor(file, return_tensors="pt").to(DEVICE)
16
-
17
- with torch.no_grad():
18
- logits = model(**encoded).logits.cpu()
19
-
20
- outputs = nn.functional.sigmoid(logits)
21
-
22
- return outputs[0][0].item(), outputs[0][1].item(), outputs[0][2].item()
23
-
24
- gr.Interface(
25
- title="Artwork scorer",
26
- description="Predicts score (0-1) for artwork.\nCould be wrong!!!\nDoes not work very well with nsfw i.e. it was not trained on it",
27
- fn=predict,
28
- allow_flagging="never",
29
- inputs=gr.Image(type="pil"),
30
- outputs=[gr.Number(label="Score"), gr.Number(label="View count ratio (probably useless)"), gr.Number(label="Upload date 0 - 2016, 1 - 2023")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  ).launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ from transformers import SiglipImageProcessor,SiglipModel
5
+ import dbimutils as utils
6
+
7
+ class ScoreClassifier(nn.Module):
8
+ def __init__(self):
9
+ super(ScoreClassifier, self).__init__()
10
+
11
+ self.classifier = nn.Sequential(
12
+ nn.Linear(256, 1),
13
+ nn.Sigmoid()
14
+ )
15
+
16
+ self.extractor = nn.Sequential(
17
+ nn.Linear(768, 512),
18
+ nn.BatchNorm1d(512),
19
+ nn.ReLU(),
20
+ nn.Linear(512, 256),
21
+ nn.BatchNorm1d(256),
22
+ nn.ReLU(),
23
+ nn.Linear(256, 256),
24
+ nn.ReLU(),
25
+ )
26
+
27
+ def forward(self, img):
28
+ return self.classifier(self.extractor(img))
29
+
30
+ from huggingface_hub import hf_hub_download
31
+ model_file = hf_hub_download(repo_id="Muinez/Image-scorer", filename="scorer.pth")
32
+
33
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
34
+ model = ScoreClassifier().to(DEVICE)
35
+ model.load_state_dict(torch.load("scorer.pth"))
36
+ model.eval()
37
+
38
+ processor = SiglipImageProcessor.from_pretrained('google/siglip-base-patch16-512')
39
+ siglip = SiglipModel.from_pretrained('google/siglip-base-patch16-512').to(DEVICE)
40
+
41
+ def predict(img):
42
+ img = utils.preprocess_image(img)
43
+ encoded = processor(img, return_tensors="pt").pixel_values.to(DEVICE)
44
+
45
+ with torch.no_grad():
46
+ score = model(siglip.get_image_features(encoded))
47
+
48
+ return score.item()
49
+
50
+ gr.Interface(
51
+ title="Artwork scorer",
52
+ description="Predicts score (0-1) for artwork.\nCould be wrong!!!\nDoes not work very well with nsfw i.e. it was not trained on it",
53
+ fn=predict,
54
+ allow_flagging="never",
55
+ inputs=gr.Image(type="pil"),
56
+ outputs=[gr.Number(label="Score")]
57
  ).launch()
dbimutils.py CHANGED
@@ -61,8 +61,8 @@ def preprocess_image(img):
61
  image = new_image.convert('RGB')
62
  image = np.asarray(image)
63
 
64
- image = make_square(image, 384)
65
- image = smart_resize(image, 384)
66
  image = image.astype(np.float32)
67
 
68
  return Image.fromarray(np.uint8(image))
 
61
  image = new_image.convert('RGB')
62
  image = np.asarray(image)
63
 
64
+ image = make_square(image, 512)
65
+ image = smart_resize(image, 512)
66
  image = image.astype(np.float32)
67
 
68
  return Image.fromarray(np.uint8(image))