"""Gradio web UI for a five-class image content (NSFW) classifier.

A ResNet-18 backbone plus a small fully connected head is loaded from
``classify_nsfw_v3.0.pth``. Uploaded images are classified and the per-class
probabilities are shown in a ``gr.Label``.
"""
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import ResNet18_Weights, resnet18

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Index i of the model output is reported as labels[i] (see predict), so this
# order must match the order used when the checkpoint was trained.
labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
theme_color = "#6C5B7B"  # primary accent colour (an elegant purple)

# Rendered with gr.Markdown so the blank lines become paragraph breaks and the
# category list (built from `labels`, so it cannot drift) renders as a list.
description = """
🎨 NSFW 图片分类器

该模型使用深度神经网络对图片内容进行分类,支持以下类别:

{categories}

🖼️ 请上传图片或点击下方示例体验
""".format(categories="\n".join(f"- {name}" for name in labels))

# Custom CSS layered on top of the Soft theme; literal braces are doubled
# because this is an f-string that injects `theme_color`.
advanced_css = f"""
.gradio-container {{
    background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
    min-height: 100vh;
}}
.header-section {{
    background: white;
    padding: 2rem;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.05);
    margin-bottom: 2rem;
}}
.result-card {{
    background: white !important;
    padding: 1.5rem !important;
    border-radius: 12px !important;
    box-shadow: 0 2px 8px rgba(108,91,123,0.1) !important;
}}
.custom-button {{
    background: {theme_color} !important;
    color: white !important;
    border: none !important;
    padding: 12px 28px !important;
    border-radius: 25px !important;
    transition: all 0.3s ease !important;
}}
.custom-button:hover {{
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(108,91,123,0.3) !important;
}}
.upload-box {{
    border: 2px dashed {theme_color} !important;
    border-radius: 15px !important;
    background: rgba(255,255,255,0.9) !important;
}}
.example-card {{
    cursor: pointer;
    transition: all 0.3s ease;
    border-radius: 12px;
    overflow: hidden;
}}
.example-card:hover {{
    transform: scale(1.02);
    box-shadow: 0 4px 12px rgba(108,91,123,0.2);
}}
.prob-bar {{
    height: 8px;
    border-radius: 4px;
    background: linear-gradient(90deg, {theme_color} 0%, #C06C84 100%);
}}
"""


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a fully connected classification head.

    The backbone's 1000-dim output is mapped to one logit per entry in
    ``labels``.
    """

    def __init__(self):
        super().__init__()
        # No pretrained download: load_state_dict below is strict, so every
        # backbone parameter/buffer is overwritten by the checkpoint anyway and
        # fetching ImageNet weights would only add a network hit at startup.
        self.cnn_layers = resnet18(weights=None)
        # NOTE(review): there is no activation between the first two Linear
        # layers. Presumably the checkpoint was trained with exactly this head,
        # so do not change it without retraining.
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, len(labels)),
        )

    def forward(self, x):
        """Return raw (un-normalised) class logits for a batch of images."""
        features = self.cnn_layers(x)
        return self.fc_layers(features)


# Resize(224) scales the shorter side to 224 px (aspect ratio kept, no crop),
# then the tensor is normalised with the ImageNet mean/std.
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

model = Classifier()
# torch.load unpickles the file; it is assumed to be a trusted local
# checkpoint. map_location lets it load on CPU-only hosts.
model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
model.eval()  # disables Dropout for inference


def predict(image_path):
    """Classify one image file and return per-class probabilities.

    Args:
        image_path: Path of the uploaded image (the input component uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability
        (a Python float).

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable
            message instead of an internal exception.
    """
    if not image_path:
        raise gr.Error("请先上传图片")
    # convert() returns a new image, so the source file can be closed at once.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    batch = preprocess(img).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        probs = torch.softmax(model(batch)[0], dim=0)
    return dict(zip(labels, probs.tolist()))


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), css=advanced_css) as demo:
    # Header
    with gr.Column(elem_classes="header-section"):
        gr.Markdown("# 🎭 智能内容识别系统", elem_id="main-title")
        gr.Markdown(description)

    # Main area
    with gr.Row():
        # Input column
        with gr.Column(scale=2):
            upload_box = gr.Image(
                type="filepath",
                label="📤 上传图片",
                elem_id="upload-box",
                elem_classes="upload-box",
                height=400,
            )
            with gr.Row():
                submit_btn = gr.Button(
                    "✨ 开始分析",
                    elem_classes="custom-button",
                    size="lg",
                )
                clear_btn = gr.Button(
                    "🔄 重新上传",
                    variant="secondary",
                    size="lg",
                )

        # Output column
        with gr.Column(scale=1):
            with gr.Column(elem_classes="result-card"):
                gr.Markdown("### 🔍 分析结果")
                result_display = gr.Label(
                    label="分类概率分布",
                    num_top_classes=3,
                    show_label=False,
                )
                # NOTE(review): this text is static; no event updates it.
                # Wiring it into the submit event would change the outputs of
                # the public "predict" API endpoint, so it is left as-is.
                gr.Markdown("**最高概率类别**: ", elem_id="dynamic-text")

    # Examples
    with gr.Column():
        gr.Markdown("### 🖼️ 示例图片")
        examples = gr.Examples(
            examples=["./example/anime.jpg", "./example/real.jpg"],
            inputs=upload_box,
            examples_per_page=2,
            label="点击使用示例",
            elem_id="example-gallery",
        )

    # Event wiring
    # Clearing the image also clears the now-stale classification result.
    clear_btn.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[upload_box, result_display],
    )
    submit_btn.click(
        fn=predict,
        inputs=upload_box,
        outputs=result_display,
        api_name="predict",
    )

if __name__ == "__main__":
    demo.launch()