csuer commited on
Commit
3e2ba9b
·
1 Parent(s): 7246cbb

Upload 8 files

Browse files
README.md CHANGED
@@ -1,13 +1,19 @@
1
  ---
2
- title: Nsfw Classification
3
- emoji: 🐠
4
- colorFrom: green
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.17.0
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Nsfw Classify
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.16.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
+ 模型数据集来自https://github.com/alex000kim/nsfw_data_scraper
13
+ 同时我还自己爬取了将近10g的资料集。预训练resnet18在训练集上跑到95% acc便停止训练,因为本身数据集噪音较大,acc过高可能出现过拟合。
14
+ 下面是摘选自原数据集对各个标签的定义:
15
+ porn - pornography images
16
+ hentai - hentai images, but also includes pornographic drawings
17
+ sexy - sexually explicit images, but not pornography. Think nude photos, playboy, bikini, etc.
18
+ neutral - safe for work neutral images of everyday things and people
19
+ drawings - safe for work drawings (including anime)
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import threading
4
+ import torch.nn as nn
5
+ from torchvision import transforms
6
+ from torchvision.models import resnet18, ResNet18_Weights
7
+ from PIL import Image
8
+ import base64
9
+ import io
10
+ import requests
11
+
12
+ # number convert to label
13
+ labels = ["drawings", "hentai", "neutral", "porn", "sexy"]
14
+ description = f"""This is a demo of classifing nsfw pictures. Label division is based on the following:
15
+ [*https://github.com/alex000kim/nsfw_data_scraper*](https://github.com/alex000kim/nsfw_data_scraper).
16
+ (If you want to test, please drop the example pictures instead of clicking)
17
+
18
+ You can continue to train this model with the same preprocess-to-images.
19
+ Finally, welcome to star my [*github repository*](https://github.com/csuer411/nsfw_classify)
20
+
21
+ Notice!!! Every image you upload will be used for further training.Delete lines 84 and 85 if you are confused by this."""
22
+ # define CNN model
23
+ class Classifier(nn.Module):
24
+ def __init__(self):
25
+ super(Classifier, self).__init__()
26
+ self.cnn_layers = resnet18(weights=ResNet18_Weights)
27
+ self.fc_layers = nn.Sequential(
28
+ nn.Linear(1000, 512),
29
+ nn.Dropout(0.3),
30
+ nn.Linear(512, 128),
31
+ nn.ReLU(),
32
+ nn.Linear(128, 5),
33
+ )
34
+
35
+ def forward(self, x):
36
+
37
+ # Extract features by convolutional layers.
38
+ x = self.cnn_layers(x)
39
+ x = self.fc_layers(x)
40
+ return x
41
+
42
+
43
+ # pre-process
44
+ preprocess = transforms.Compose(
45
+ [
46
+ transforms.Resize(224),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
49
+ ]
50
+ )
51
+ # load model
52
+ model = Classifier()
53
+ model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
54
+ model.eval()
55
+
56
+
57
+ def img_convert(inp):
58
+ with io.BytesIO() as f:
59
+ inp.save(f, format="JPEG")
60
+ img_data = f.getvalue()
61
+ img_base64 = base64.b64encode(img_data)
62
+ return img_base64
63
+
64
+
65
+ def send_server(prediction, inp):
66
+ img_base64 = img_convert(inp)
67
+ max_index = prediction.argmax()
68
+ msg = (
69
+ "{"
70
+ + f'"max_label": "{max_index}{prediction[max_index]:.4f}",'
71
+ + f'"img_base64": "{img_base64}"'
72
+ + "}"
73
+ )
74
+ response = requests.post("https://micono.xyz/text", data=msg)
75
+ print(img_base64)
76
+
77
+
78
+ def predict(inp):
79
+ temp_inp = inp
80
+ inp = preprocess(inp).unsqueeze(0)
81
+ with torch.no_grad():
82
+ prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
83
+ result = {labels[i]: float(prediction[i]) for i in range(5)}
84
+ thread = threading.Thread(target=send_server, args=(prediction, temp_inp))
85
+ thread.start()
86
+ return result
87
+
88
+
89
+ inputs = gr.components.Image(type='pil')
90
+ outputs = gr.components.Label(num_top_classes=2)
91
+ gr.Interface(
92
+ fn=predict, inputs=inputs, outputs=outputs, examples=["./example/anime.jpg", "./example/real.jpg"], description=description,
93
+ ).launch()
classify_nsfw_v2.0.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c87f9213b2ee7ad7821afa560219424a59842aac4306e2c293a1e00c906614c
3
+ size 104871373
classify_nsfw_v3.0.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9323fa4e8efcbefe496724c0c41ed9bd388905252a3c5699166eed36d798b81e
3
+ size 49157271
example/anime.jpg ADDED
example/real.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==3.16.2
2
+ Pillow==9.4.0
3
+ requests==2.28.2
4
+ torch==1.13.1
5
+ torchvision==0.14.1