xbarusui's picture
first edit app.py
1d87d9c verified
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import gradio as gr
import matplotlib.pyplot as plt
import io
import re
import os
from datetime import datetime
import spaces
@spaces.GPU
def load_model():
    """Load the moderation tokenizer and model and move the model to a device.

    The checkpoint is loaded as a sequence-classification model with
    ``problem_type="regression"``.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``device`` is the CUDA device
        when ``torch.cuda.is_available()`` is true, otherwise the CPU device;
        the returned model has already been moved onto it.
    """
    checkpoint = "oshizo/japanese-sexual-moderation-v2"

    tok = AutoTokenizer.from_pretrained(checkpoint)
    net = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        problem_type="regression",
    )

    # Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")

    return tok, net.to(target), target
@spaces.GPU
def analyze_text(text, tokenizer, model, device):
    """Score a single piece of text with the model.

    Args:
        text: The string to score.
        tokenizer: Callable that turns a list of strings into a mapping of
            tensors (called with ``return_tensors="pt"``).
        model: Model whose output exposes a single-element ``logits`` tensor.
        device: Device the encoded tensors are moved to before the forward
            pass.

    Returns:
        The model's logit for ``text`` as a Python number.
    """
    # Inputs are padded/truncated to a fixed length of 64 tokens.
    batch = tokenizer(
        [text],
        padding='max_length',
        truncation=True,
        max_length=64,
        return_tensors="pt",
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        logits = model(**on_device).logits
    return logits.item()
def split_text(text, split_by='sentence'):
    """Split ``text`` into non-empty, whitespace-stripped segments.

    This helper only does string processing, so it is not wrapped in
    ``spaces.GPU``.

    Args:
        text: The input string.
        split_by: ``'sentence'`` splits on sentence-ending punctuation
            (``。``, ``!``, ``?`` and their full-width forms ``！``, ``？``).
            Any other value splits on newlines.

    Returns:
        A list of segments in their original order, with surrounding
        whitespace removed and empty segments dropped.
    """
    if split_by == 'sentence':
        # A character class is used instead of the alternation '。|!|?':
        # a bare '?' is a regex quantifier and makes the pattern invalid
        # ("nothing to repeat"), so the original raised on every call.
        pieces = re.split(r'[。!?！？]', text)
    else:  # split by line
        pieces = text.split('\n')

    stripped = (piece.strip() for piece in pieces)
    return [piece for piece in stripped if piece]
def create_graph(texts, scores):
    """Build a bar chart with one bar per score.

    The figure is created with ``matplotlib.figure.Figure`` instead of
    ``plt.subplots``. pyplot keeps every figure it creates in a global
    registry until it is explicitly closed, so a server that builds one
    figure per request would accumulate them; a standalone ``Figure`` is
    released as soon as nothing references it.

    Plotting is CPU-only, so this helper is not wrapped in ``spaces.GPU``.

    Args:
        texts: The analysed text segments. Not used for drawing; kept so the
            signature stays compatible with existing callers.
        scores: Sequence of numeric scores, plotted in order.

    Returns:
        The matplotlib ``Figure`` containing the bar chart. Bars are labelled
        1..len(scores) on the x axis.
    """
    # Local import keeps this block self-contained; matplotlib is already a
    # dependency of the file via pyplot.
    from matplotlib.figure import Figure

    positions = range(len(scores))

    fig = Figure(figsize=(12, 6))
    ax = fig.subplots()
    ax.bar(positions, scores)

    # NOTE(review): the labels are Japanese; presumably a CJK-capable font
    # must be configured for them to render — confirm in the deployment.
    ax.set_xlabel('テキスト番号')
    ax.set_ylabel('スコア')
    ax.set_title("分析結果")

    # Show 1-based segment numbers under the bars.
    ax.set_xticks(positions)
    ax.set_xticklabels(range(1, len(scores) + 1))

    fig.tight_layout()
    return fig
def create_not_r18_text(texts, scores, threshold=0.4):
    """Join the segments whose score is below ``threshold``, one per line.

    Segments scoring at or above the threshold are replaced by an empty
    line, so the remaining lines keep their original positions.

    This is plain string processing, so it is not wrapped in ``spaces.GPU``.

    Args:
        texts: The text segments.
        scores: One score per segment, paired with ``texts`` by position.
        threshold: Exclusive upper bound for a segment to be kept.
            Defaults to 0.4, the cutoff previously hard-coded here.

    Returns:
        The kept segments and blank placeholders joined with newlines.
    """
    return '\n'.join(
        # An empty string marks the position of an excluded segment.
        text if score < threshold else ''
        for text, score in zip(texts, scores)
    )
# Load the tokenizer, model, and device once at import time; process_text
# reads these module-level names for every request.
tokenizer, model, device = load_model()
@spaces.GPU
def process_text(text, split_by):
    """Split the input, score each segment, and build the three UI outputs.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: The raw text to analyse.
        split_by: Splitting mode passed through to ``split_text``.

    Returns:
        A ``(result, graph, not_r18_text)`` tuple: a dict with the segments
        under ``"texts"`` and their scores under ``"scores"``, the figure
        from ``create_graph``, and the string from ``create_not_r18_text``.
    """
    segments = split_text(text, split_by)

    # Score the segments one at a time, in order.
    segment_scores = []
    for segment in segments:
        segment_scores.append(analyze_text(segment, tokenizer, model, device))

    figure = create_graph(segments, segment_scores)
    filtered_text = create_not_r18_text(segments, segment_scores)

    summary = {"texts": segments, "scores": segment_scores}
    return summary, figure, filtered_text
# Define the Gradio interface: a text box and a split-mode selector feed
# process_text, whose three return values fill the JSON, plot and text outputs.
iface = gr.Interface(
    title="テキスト分析API",
    description="テキストを入力し、R18判定と分析を行います。",
    fn=process_text,
    inputs=[
        gr.Textbox(label="テキスト入力"),
        gr.Radio(["sentence", "line"], value="sentence", label="分割方法"),
    ],
    outputs=[
        gr.JSON(label="分析結果"),
        gr.Plot(label="スコアグラフ"),
        gr.Textbox(label="R18判定除外テキスト"),
    ],
)

# Start the server only when this file is run as a script.
if __name__ == "__main__":
    iface.launch(server_port=7860, server_name="0.0.0.0")