"""Gradio app that scores pieces of Japanese text with a moderation model.

The input text is split into sentences or lines, each piece is scored by the
``oshizo/japanese-sexual-moderation-v2`` sequence-classification model loaded
in regression mode, and the app returns the raw scores, a bar chart of them,
and the input with every piece scoring at or above the threshold blanked out.
"""

import io
import os
import re
from datetime import datetime

import gradio as gr
import matplotlib.pyplot as plt
import spaces
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Sentence terminators used by split_text. A character class is used because
# a bare "?" alternative is a regex quantifier and makes the pattern invalid.
# Both ASCII and full-width "!" / "?" are accepted.
_SENTENCE_DELIMITERS = re.compile(r"[。!?！？]")


@spaces.GPU
def load_model():
    """Load the tokenizer and the regression model, and pick a device.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``device`` is CUDA when
        ``torch.cuda.is_available()`` is true, otherwise CPU, and the model
        has already been moved to it.
    """
    model_id = "oshizo/japanese-sexual-moderation-v2"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        problem_type="regression"
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return tokenizer, model, device


@spaces.GPU
def analyze_text(text, tokenizer, model, device):
    """Return the model's regression score for a single piece of text.

    Args:
        text: The string to score. Input is padded/truncated to 64 tokens.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Model returned by ``load_model``.
        device: Device the model lives on; the encoded inputs are moved there.

    Returns:
        The single logit produced by the model, as a Python float.
    """
    with torch.no_grad():
        encoding = tokenizer(
            [text],
            padding='max_length',
            truncation=True,
            max_length=64,
            return_tensors="pt",
        )
        encoding = {k: v.to(device) for k, v in encoding.items()}
        score = model(**encoding).logits.item()
    return score


def split_text(text: str, split_by: str = 'sentence') -> list:
    """Split *text* into non-empty, stripped pieces.

    Args:
        text: The text to split. Empty or ``None`` input yields ``[]``.
        split_by: ``'sentence'`` splits on 。, ! and ? (ASCII or full-width);
            any other value splits on newlines.

    Returns:
        A list of pieces with surrounding whitespace removed and empty
        pieces dropped.
    """
    if not text:
        return []
    if split_by == 'sentence':
        pieces = _SENTENCE_DELIMITERS.split(text)
    else:  # split by line
        pieces = text.split('\n')
    return [piece.strip() for piece in pieces if piece.strip()]


def create_graph(texts, scores):
    """Build a bar chart with one bar per score.

    Args:
        texts: The scored pieces of text (not used for drawing; kept for
            interface compatibility).
        scores: Sequence of numeric scores, one per piece.

    Returns:
        The matplotlib figure. Bars are labelled 1..len(scores) on the x axis.
    """
    positions = range(len(scores))
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(positions, scores)
    ax.set_xlabel('テキスト番号')
    ax.set_ylabel('スコア')
    ax.set_title("分析結果")
    ax.set_xticks(positions)
    ax.set_xticklabels(range(1, len(scores) + 1))
    # Lay out this specific figure rather than pyplot's "current" one.
    fig.tight_layout()
    # Unregister the figure from pyplot so figures do not pile up across
    # requests; the Figure object itself stays usable for rendering.
    plt.close(fig)
    return fig


def create_not_r18_text(texts, scores, threshold=0.4):
    """Join the pieces whose score is below *threshold*, one per line.

    Pieces scoring at or above the threshold are replaced by an empty line so
    that the remaining pieces keep their original positions.

    Args:
        texts: The pieces of text.
        scores: The score of each piece, in the same order as *texts*.
        threshold: Pieces with ``score < threshold`` are kept. Defaults to 0.4.

    Returns:
        The kept pieces joined with newlines.
    """
    return '\n'.join(
        text if score < threshold else ''
        for text, score in zip(texts, scores)
    )


tokenizer, model, device = load_model()


@spaces.GPU
def process_text(text, split_by):
    """Split, score and summarise *text*; this is the Gradio callback.

    Args:
        text: The raw text from the input textbox.
        split_by: ``"sentence"`` or ``"line"``, forwarded to ``split_text``.

    Returns:
        A ``(result, graph, not_r18_text)`` tuple: a dict with the ``"texts"``
        and ``"scores"`` lists, the bar-chart figure, and the text with
        pieces at or above the threshold blanked out.
    """
    texts = split_text(text, split_by)
    scores = [analyze_text(t, tokenizer, model, device) for t in texts]

    graph = create_graph(texts, scores)
    not_r18_text = create_not_r18_text(texts, scores)

    result = {
        "texts": texts,
        "scores": scores,
    }

    return result, graph, not_r18_text


# Gradio interface definition
iface = gr.Interface(
    fn=process_text,
    inputs=[
        gr.Textbox(label="テキスト入力"),
        gr.Radio(["sentence", "line"], label="分割方法", value="sentence")
    ],
    outputs=[
        gr.JSON(label="分析結果"),
        gr.Plot(label="スコアグラフ"),
        gr.Textbox(label="R18判定除外テキスト")
    ],
    title="テキスト分析API",
    description="テキストを入力し、R18判定と分析を行います。"
)

# Start the server
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)