import gradio as gr
import spaces
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from huggingface_hub import snapshot_download
from pathlib import Path

# Download and prepare the model weights
mistral_models_path = Path.home().joinpath('mistral_models', 'Pixtral')
mistral_models_path.mkdir(parents=True, exist_ok=True)

snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=mistral_models_path,
)

# Load the tokenizer and model
tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
model = Transformer.from_folder(mistral_models_path)


# Inference
@spaces.GPU
def mistral_inference(prompt, image_url):
    # Build a chat request whose single user message contains the image URL and the text prompt
    completion_request = ChatCompletionRequest(
        messages=[UserMessage(content=[ImageURLChunk(image_url=image_url), TextChunk(text=prompt)])]
    )
    encoded = tokenizer.encode_chat_completion(completion_request)
    images = encoded.images
    tokens = encoded.tokens

    out_tokens, _ = generate(
        [tokens],
        model,
        images=[images],
        max_tokens=1024,
        temperature=0.35,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    result = tokenizer.decode(out_tokens[0])
    return result


# Gradio interface
def process_input(text, image_url):
    result = mistral_inference(text, image_url)
    # Return the URL as well, so the submitted image is shown next to the output
    return result, image_url


with gr.Blocks() as demo:
    gr.Markdown("## Image description generation with the Pixtral model")
    with gr.Row():
        text_input = gr.Textbox(label="Text prompt", placeholder="e.g. Describe the image.")
        image_input = gr.Textbox(label="Image URL", placeholder="e.g. https://example.com/image.png")
    result_output = gr.Textbox(label="Model output", lines=8, max_lines=20)  # lines/max_lines control the visible height of the box
    image_output = gr.Image(label="Input image")  # Displays the image at the submitted URL
    submit_button = gr.Button("Run inference")

    # On click, run the model and show its output alongside the input image
    submit_button.click(process_input, inputs=[text_input, image_input], outputs=[result_output, image_output])

demo.launch()
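# Usage note (assumptions, not confirmed by this script alone): the `spaces.GPU`
# decorator targets Hugging Face ZeroGPU Spaces; running the script elsewhere is
# assumed to require a CUDA GPU with enough memory for the 12B weights, in which
# case the `spaces` import and the decorator can be dropped.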