import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
# Configuration
# Class names; index i names the i-th probability produced in predict().
labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
theme_color = "#6C5B7B"  # primary accent colour (an elegant purple)
# Introductory text shown in the page header. This is a plain literal with no
# placeholders, so it is deliberately NOT passed through str.format(): a
# format call would do nothing today and would raise as soon as a literal
# brace were added to the text.
description = """
🎨 NSFW 图片分类器
该模型使用深度神经网络对图片内容进行分类,支持以下类别:
- Drawings - 艺术绘画作品
- Hentai - 二次元成人内容
- Neutral - 日常安全内容
- Porn - 露骨成人内容
- Sexy - 性感但不露骨内容
🖼️ 请上传图片或点击下方示例体验
"""
# Model definition and preprocessing are unchanged from the previous version;
# the actual code follows the CSS block below.
# Advanced CSS styles
# The stylesheet is written as ordinary CSS (single braces). Every occurrence
# of the @THEME_COLOR@ marker is replaced with theme_color afterwards.
_CSS_TEMPLATE = """
.gradio-container {
background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
min-height: 100vh;
}
.header-section {
background: white;
padding: 2rem;
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0,0,0,0.05);
margin-bottom: 2rem;
}
.result-card {
background: white !important;
padding: 1.5rem !important;
border-radius: 12px !important;
box-shadow: 0 2px 8px rgba(108,91,123,0.1) !important;
}
.custom-button {
background: @THEME_COLOR@ !important;
color: white !important;
border: none !important;
padding: 12px 28px !important;
border-radius: 25px !important;
transition: all 0.3s ease !important;
}
.custom-button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(108,91,123,0.3) !important;
}
.upload-box {
border: 2px dashed @THEME_COLOR@ !important;
border-radius: 15px !important;
background: rgba(255,255,255,0.9) !important;
}
.example-card {
cursor: pointer;
transition: all 0.3s ease;
border-radius: 12px;
overflow: hidden;
}
.example-card:hover {
transform: scale(1.02);
box-shadow: 0 4px 12px rgba(108,91,123,0.2);
}
.prob-bar {
height: 8px;
border-radius: 4px;
background: linear-gradient(90deg, @THEME_COLOR@ 0%, #C06C84 100%);
}
"""
advanced_css = _CSS_TEMPLATE.replace("@THEME_COLOR@", theme_color)
# Define CNN model
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully connected head.

    The backbone's 1000-dimensional output is fed through
    ``Linear(1000, 512) -> Dropout(0.3) -> Linear(512, 128) -> ReLU ->
    Linear(128, 5)``, yielding 5 raw scores (logits).

    Args:
        pretrained: When True (the default, matching the original behaviour)
            the backbone is created with ``ResNet18_Weights.DEFAULT``, which
            torchvision fetches if it is not cached locally. Pass False when
            a complete checkpoint is loaded right after construction, so no
            download is attempted.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone_weights = ResNet18_Weights.DEFAULT if pretrained else None
        # The attribute names and the order inside the Sequential determine
        # the state-dict keys, so they must stay exactly as they are for
        # existing checkpoints to load.
        self.cnn_layers = resnet18(weights=backbone_weights)
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 5),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and the head; return the raw scores."""
        x = self.cnn_layers(x)
        x = self.fc_layers(x)
        return x


# Pre-process
# Resize(224) scales the shorter image side to 224 px and keeps the aspect
# ratio (no crop is applied). The Normalize constants are the standard
# ImageNet channel mean/std used with torchvision's pretrained models.
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load model
# load_state_dict() is strict by default: the checkpoint must supply every
# parameter and buffer, so pretrained backbone weights would be overwritten
# immediately. Skipping them avoids a download at startup.
model = Classifier(pretrained=False)
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs.
state_dict = torch.load(
    "classify_nsfw_v3.0.pth",
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables Dropout and freezes BatchNorm statistics
def predict(image_path):
    """Classify the image stored at ``image_path``.

    Args:
        image_path: Filesystem path of the image to classify. The UI passes
            None when no image has been uploaded.

    Returns:
        A dict mapping each entry of ``labels`` to its softmax probability
        (a Python float).

    Raises:
        gr.Error: If ``image_path`` is None or empty, so the user sees a
            readable message instead of an internal exception.
    """
    if not image_path:
        raise gr.Error("请先上传图片")
    # convert() returns a new, fully loaded image, so the file handle can be
    # released as soon as the with-block ends.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    # unsqueeze(0) adds the leading batch dimension the model is called with.
    batch = preprocess(rgb_image).unsqueeze(0)
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair every label with its probability instead of hard-coding the count.
    return dict(zip(labels, probabilities.tolist()))
with gr.Blocks(theme=gr.themes.Soft(), css=advanced_css) as demo:
    # Header section
    with gr.Column(elem_classes="header-section"):
        gr.Markdown("# 🎭 智能内容识别系统", elem_id="main-title")
        # The description is plain multi-line text. HTML collapses newlines,
        # so turn them into explicit <br> tags to keep one item per line.
        gr.HTML(description.strip().replace("\n", "<br>"))
    # Main area
    with gr.Row():
        # Input column
        with gr.Column(scale=2):
            upload_box = gr.Image(
                type="filepath",
                label="📤 上传图片",
                elem_id="upload-box",
                elem_classes="upload-box",
                height=400,
            )
            with gr.Row():
                submit_btn = gr.Button(
                    "✨ 开始分析",
                    elem_classes="custom-button",
                    size="lg",
                )
                clear_btn = gr.Button(
                    "🔄 重新上传",
                    variant="secondary",
                    size="lg",
                )
        # Output column
        with gr.Column(scale=1):
            with gr.Column(elem_classes="result-card"):
                gr.Markdown("### 🔍 分析结果")
                result_display = gr.Label(
                    label="分类概率分布",
                    num_top_classes=3,
                    show_label=False,
                )
                gr.Markdown("**最高概率类别**: ", elem_id="dynamic-text")
    # Examples section
    with gr.Column():
        gr.Markdown("### 🖼️ 示例图片")
        examples = gr.Examples(
            examples=["./example/anime.jpg", "./example/real.jpg"],
            inputs=upload_box,
            examples_per_page=2,
            label="点击使用示例",
            elem_id="example-gallery",
        )
    # Interaction logic (event listeners are registered inside the Blocks context)
    # Reset the uploaded image AND the previous result, so stale probabilities
    # are not left next to an empty upload box.
    clear_btn.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[upload_box, result_display],
    )
    submit_btn.click(
        fn=predict,
        inputs=upload_box,
        outputs=result_display,
        api_name="predict",
    )
# Launch the interface
demo.launch()