# Page banner captured together with this file: "Spaces: Runtime error".
import io
import os
import re
from datetime import datetime

import gradio as gr
import matplotlib.pyplot as plt
import spaces
import torch
from matplotlib.figure import Figure
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def load_model(model_id="oshizo/japanese-sexual-moderation-v2"):
    """Load the tokenizer and sequence-classification model and pick a device.

    Args:
        model_id: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to the model this app was
            written for.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``device`` is CUDA when
        ``torch.cuda.is_available()`` is true, otherwise CPU, and the
        returned model has already been moved to it.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        problem_type="regression",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return tokenizer, model, device
def analyze_text(text, tokenizer, model, device, max_length=64):
    """Score one text with the model and return the result as a Python float.

    Args:
        text: The string to score.
        tokenizer: Callable tokenizer; invoked on ``[text]`` and expected to
            return a mapping of tensors (``return_tensors="pt"``).
        model: Callable model whose output exposes ``.logits``.
        device: Device the encoded tensors are moved to before the forward
            pass.
        max_length: Token length the input is padded/truncated to. Text that
            tokenizes longer than this is truncated, so only its beginning
            is scored.

    Returns:
        ``logits.item()`` of the model output, i.e. a single float. This
        requires the logits tensor to hold exactly one element.
    """
    encoding = tokenizer(
        [text],
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoding = {name: tensor.to(device) for name, tensor in encoding.items()}
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        return model(**encoding).logits.item()
# Sentence terminators: Japanese full stop plus ASCII and full-width "!" / "?".
# A character class is used because a bare "?" as a regex alternative is a
# quantifier with nothing to repeat and makes the pattern fail to compile.
_SENTENCE_END_RE = re.compile(r'[。!?！？]')


def split_text(text, split_by='sentence'):
    """Split text into non-empty, whitespace-stripped pieces.

    Args:
        text: The input string.
        split_by: ``'sentence'`` splits on sentence terminators
            (``。``, ``!``, ``?`` and their full-width forms); any other
            value splits on newlines.

    Returns:
        A list of pieces in input order. Pieces that are empty after
        stripping are dropped, and the terminators themselves are not kept.
    """
    if split_by == 'sentence':
        pieces = _SENTENCE_END_RE.split(text)
    else:  # split by line
        pieces = text.split('\n')
    return [piece.strip() for piece in pieces if piece.strip()]
def create_graph(texts, scores):
    """Build a bar chart with one bar per score.

    Args:
        texts: Unused; kept so the signature stays compatible with callers.
        scores: Sequence of numeric scores, one bar each. The x tick labels
            are the 1-based positions of the scores.

    Returns:
        The matplotlib ``Figure`` containing the chart.
    """
    # Build the Figure directly rather than through pyplot: pyplot keeps a
    # global reference to every figure it creates until it is closed, which
    # accumulates one figure per request in a long-running server.
    fig = Figure(figsize=(12, 6))
    ax = fig.subplots()
    positions = range(len(scores))
    ax.bar(positions, scores)
    ax.set_xlabel('テキスト番号')
    ax.set_ylabel('スコア')
    ax.set_title("分析結果")
    ax.set_xticks(positions)
    ax.set_xticklabels(range(1, len(scores) + 1))
    # Lay out this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()
    return fig
def create_not_r18_text(texts, scores, threshold=0.4):
    """Join the texts whose score is below the threshold, one per line.

    Args:
        texts: Sequence of text pieces.
        scores: Scores paired positionally with ``texts``.
        threshold: A piece is kept only when its score is strictly below
            this value.

    Returns:
        A newline-joined string. A piece at or above the threshold is
        replaced by an empty line so the remaining pieces keep their
        original line positions. Pairing stops at the shorter of the two
        sequences.
    """
    return '\n'.join(
        text if score < threshold else ''  # blank line marks an excluded piece
        for text, score in zip(texts, scores)
    )
# Loaded once at import time; process_text reads these module-level names on
# every request instead of reloading the model.
tokenizer, model, device = load_model()
# NOTE(review): `spaces` is imported by this file but was otherwise unused.
# On ZeroGPU hardware the GPU work must run inside a `@spaces.GPU` function;
# the decorator is documented as having no effect elsewhere. Confirm which
# hardware this Space runs on.
@spaces.GPU
def _score_texts(texts):
    """Score each text with the module-level tokenizer/model; returns floats."""
    return [analyze_text(t, tokenizer, model, device) for t in texts]


def process_text(text, split_by):
    """Split the input, score every piece, and build the three UI outputs.

    Args:
        text: Raw text from the UI. ``None`` is treated as an empty string.
        split_by: Split mode forwarded to ``split_text``.

    Returns:
        A ``(result, graph, not_r18_text)`` tuple: ``result`` is a dict with
        the ``"texts"`` and ``"scores"`` lists, ``graph`` is the value
        returned by ``create_graph``, and ``not_r18_text`` is the value
        returned by ``create_not_r18_text``.
    """
    texts = split_text(text or "", split_by)
    scores = _score_texts(texts)
    graph = create_graph(texts, scores)
    not_r18_text = create_not_r18_text(texts, scores)
    result = {
        "texts": texts,
        "scores": scores,
    }
    return result, graph, not_r18_text
# Gradio interface definition: a text box and a split-mode selector in;
# JSON result, score plot, and filtered text out.
iface = gr.Interface(
    fn=process_text,
    inputs=[
        gr.Textbox(label="テキスト入力"),
        gr.Radio(["sentence", "line"], label="分割方法", value="sentence"),
    ],
    outputs=[
        gr.JSON(label="分析結果"),
        gr.Plot(label="スコアグラフ"),
        gr.Textbox(label="R18判定除外テキスト"),
    ],
    title="テキスト分析API",
    description="テキストを入力し、R18判定と分析を行います。",
)

# Start the server only when run as a script; bind to all interfaces on 7860.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)