import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from flask import Flask, request, jsonify, render_template_string

# Flask application setup
app = Flask(__name__)

# Select the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("inu-ai/alpaca-guanaco-japanese-gpt-1b", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("inu-ai/alpaca-guanaco-japanese-gpt-1b").to(device)

# Constants
MAX_ASSISTANT_LENGTH = 100   # token budget for the assistant's reply
MAX_INPUT_LENGTH = 1024      # maximum total sequence length

# Prompt templates. The literal "\n" sequences (raw strings) are intentional:
# the model expects single-line prompts with escaped newlines, which are
# unescaped again in format_output() below.
INPUT_PROMPT = r'\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n'
NO_INPUT_PROMPT = r'\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n応答:\n'

# HTML template for the chat page. Only the "Chat Interface" title and
# heading survive from the source; the markup below is a minimal
# reconstruction that posts JSON to the /generate endpoint.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Chat Interface</title>
</head>
<body>
    <h1>Chat Interface</h1>
    <div id="chat" style="white-space: pre-wrap;"></div>
    <input id="message" type="text" size="60">
    <button onclick="send()">Send</button>
    <script>
        // Placeholder role instruction (assumption; customize as needed)
        const roleInstruction = [
            "User:あなたは親切なアシスタントです。",
            "Assistant:わかりました。"
        ];
        let history = [];
        async function send() {
            const msg = document.getElementById("message").value;
            const res = await fetch("/generate", {
                method: "POST",
                headers: {"Content-Type": "application/json"},
                body: JSON.stringify({
                    role_instruction: roleInstruction,
                    conversation_history: history,
                    new_conversation: msg
                })
            });
            const data = await res.json();
            history = data.conversation_history;
            document.getElementById("chat").innerText += "User:" + msg + "\\n" + data.response + "\\n";
            document.getElementById("message").value = "";
        }
    </script>
</body>
</html>
""" def prepare_input(role_instruction, conversation_history, new_conversation): """入力テキストを整形する関数""" instruction = "".join([f"{text}\n" for text in role_instruction]) instruction += "\n".join(conversation_history) input_text = f"User:{new_conversation}" return INPUT_PROMPT.format(instruction=instruction, input=input_text) def format_output(output): """生成された出力を整形する関数""" return output.lstrip("").rstrip("").replace("[SEP]", "").replace("\\n", "\n") def trim_conversation_history(conversation_history, max_length): """会話履歴を最大長に収めるために調整する関数""" while len(conversation_history) > 2 and sum([len(tokenizer.encode(text, add_special_tokens=False)) for text in conversation_history]) + max_length > MAX_INPUT_LENGTH: conversation_history.pop(0) conversation_history.pop(0) return conversation_history def generate_response(role_instruction, conversation_history, new_conversation): """新しい会話に対する応答を生成する関数""" conversation_history = trim_conversation_history(conversation_history, MAX_ASSISTANT_LENGTH) input_text = prepare_input(role_instruction, conversation_history, new_conversation) token_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt") with torch.no_grad(): output_ids = model.generate( token_ids.to(model.device), min_length=len(token_ids[0]), max_length=min(MAX_INPUT_LENGTH, len(token_ids[0]) + MAX_ASSISTANT_LENGTH), temperature=0.7, do_sample=True, pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, bad_words_ids=[[tokenizer.unk_token_id]] ) output = tokenizer.decode(output_ids.tolist()[0]) formatted_output_all = format_output(output) response = f"Assistant:{formatted_output_all.split('応答:')[-1].strip()}" conversation_history.append(f"User:{new_conversation}".replace("\n", "\\n")) conversation_history.append(response.replace("\n", "\\n")) return formatted_output_all, response @app.route('/') def home(): """ホームページをレンダリング""" return render_template_string(HTML_TEMPLATE) @app.route('/generate', methods=['POST']) def generate(): """Flaskエンドポイント: /generate""" data = request.json role_instruction = data.get('role_instruction', []) conversation_history = data.get('conversation_history', []) new_conversation = data.get('new_conversation', "") if not role_instruction or not new_conversation: return jsonify({"error": "role_instruction and new_conversation are required fields"}), 400 formatted_output_all, response = generate_response(role_instruction, conversation_history, new_conversation) return jsonify({"response": response, "conversation_history": conversation_history}) if __name__ == '__main__': app.run(debug=True, host="0.0.0.0", port=7860)