aixsatoshi committed on
Commit
8ff201d
·
verified ·
1 Parent(s): aaa844d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -42
app.py CHANGED
@@ -1,45 +1,30 @@
1
  import gradio as gr
 
2
  from mistral_inference.transformer import Transformer
3
  from mistral_inference.generate import generate
4
  from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
5
- from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk, ImageChunk
6
  from mistral_common.protocol.instruct.request import ChatCompletionRequest
7
  from huggingface_hub import snapshot_download
8
  from pathlib import Path
9
- import base64
10
- import spaces
11
 
12
  # モデルのダウンロードと準備
13
  mistral_models_path = Path.home().joinpath('mistral_models', 'Pixtral')
14
  mistral_models_path.mkdir(parents=True, exist_ok=True)
15
 
16
- snapshot_download(repo_id="mistralai/Pixtral-12B-2409",
17
  allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
18
  local_dir=mistral_models_path)
19
 
20
-
21
  # トークナイザーとモデルのロード
22
  tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
23
  model = Transformer.from_folder(mistral_models_path)
24
 
25
- # 画像ファイルをbase64に変換するヘルパー関数
26
- def image_to_base64(image_path):
27
- with open(image_path, "rb") as image_file:
28
- encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
29
- return encoded_string
30
-
31
  # 推論処理
32
  @spaces.GPU
33
- def mistral_inference(prompt, image_url=None, image_file=None):
34
- if image_file is not None:
35
- # 画像ファイルがアップロードされた場合
36
- image_chunk = ImageChunk(image_base64=image_to_base64(image_file))
37
- else:
38
- # 画像URLが指定された場合
39
- image_chunk = ImageURLChunk(image_url=image_url)
40
-
41
  completion_request = ChatCompletionRequest(
42
- messages=[UserMessage(content=[image_chunk, TextChunk(text=prompt)])]
43
  )
44
 
45
  encoded = tokenizer.encode_chat_completion(completion_request)
@@ -57,8 +42,7 @@ def get_labels(language):
57
  'en': {
58
  'title': "Pixtral Model Image Description",
59
  'text_prompt': "Text Prompt",
60
- 'image_url': "Image URL (or leave blank if uploading an image)",
61
- 'image_upload': "Upload Image",
62
  'output': "Model Output",
63
  'image_display': "Input Image",
64
  'submit': "Run Inference"
@@ -66,8 +50,7 @@ def get_labels(language):
66
  'zh': {
67
  'title': "Pixtral模型图像描述",
68
  'text_prompt': "文本提示",
69
- 'image_url': "图片网址 (如果上传图片,请留空)",
70
- 'image_upload': "上传图片",
71
  'output': "模型输出",
72
  'image_display': "输入图片",
73
  'submit': "运行推理"
@@ -75,8 +58,7 @@ def get_labels(language):
75
  'jp': {
76
  'title': "Pixtralモデルによる画像説明生成",
77
  'text_prompt': "テキストプロンプト",
78
- 'image_url': "画像URL(画像をアップロードする場合は空白)",
79
- 'image_upload': "画像をアップロード",
80
  'output': "モデルの出力結果",
81
  'image_display': "入力された画像",
82
  'submit': "推論を実行"
@@ -85,19 +67,13 @@ def get_labels(language):
85
  return labels[language]
86
 
87
  # Gradioインターフェース
88
- def process_input(text, image_url, image_file):
89
- if image_file is not None:
90
- result = mistral_inference(text, image_file=image_file)
91
- image_display = f'<img src="data:image/png;base64,{image_to_base64(image_file)}" alt="Input Image" width="300">'
92
- else:
93
- result = mistral_inference(text, image_url=image_url)
94
- image_display = f'<img src="{image_url}" alt="Input Image" width="300">'
95
-
96
- return result, image_display
97
 
98
  def update_ui(language):
99
  labels = get_labels(language)
100
- return labels['title'], labels['text_prompt'], labels['image_url'], labels['image_upload'], labels['output'], labels['image_display'], labels['submit']
101
 
102
  with gr.Blocks() as demo:
103
  language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value='en')
@@ -105,22 +81,20 @@ with gr.Blocks() as demo:
105
  title = gr.Markdown("## Pixtral Model Image Description")
106
  with gr.Row():
107
  text_input = gr.Textbox(label="Text Prompt", placeholder="e.g. Describe the image.")
 
108
 
109
- image_url_input = gr.Textbox(label="Image URL (or leave blank if uploading an image)", placeholder="e.g. https://example.com/image.png")
110
- image_file_input = gr.Image(label="Upload Image", type="filepath", optional=True)
111
-
112
  result_output = gr.Textbox(label="Model Output", lines=8, max_lines=20) # 高さ500ピクセルに相当するように調整
113
  image_output = gr.HTML(label="Input Image") # 入力画像URLを表示するための場所
114
 
115
  submit_button = gr.Button("Run Inference")
116
 
117
- submit_button.click(process_input, inputs=[text_input, image_url_input, image_file_input], outputs=[result_output, image_output])
118
 
119
  # 言語変更時にUIラベルを更新
120
  language_choice.change(
121
  fn=update_ui,
122
  inputs=[language_choice],
123
- outputs=[title, text_input, image_url_input, image_file_input, result_output, image_output, submit_button]
124
  )
125
 
126
- demo.launch()
 
1
  import gradio as gr
2
+ import spaces
3
  from mistral_inference.transformer import Transformer
4
  from mistral_inference.generate import generate
5
  from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
6
+ from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
7
  from mistral_common.protocol.instruct.request import ChatCompletionRequest
8
  from huggingface_hub import snapshot_download
9
  from pathlib import Path
 
 
10
 
11
  # モデルのダウンロードと準備
12
  mistral_models_path = Path.home().joinpath('mistral_models', 'Pixtral')
13
  mistral_models_path.mkdir(parents=True, exist_ok=True)
14
 
15
+ snapshot_download(repo_id="mistral-community/pixtral-12b-240910",
16
  allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
17
  local_dir=mistral_models_path)
18
 
 
19
  # トークナイザーとモデルのロード
20
  tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
21
  model = Transformer.from_folder(mistral_models_path)
22
 
 
 
 
 
 
 
23
  # 推論処理
24
  @spaces.GPU
25
+ def mistral_inference(prompt, image_url):
 
 
 
 
 
 
 
26
  completion_request = ChatCompletionRequest(
27
+ messages=[UserMessage(content=[ImageURLChunk(image_url=image_url), TextChunk(text=prompt)])]
28
  )
29
 
30
  encoded = tokenizer.encode_chat_completion(completion_request)
 
42
  'en': {
43
  'title': "Pixtral Model Image Description",
44
  'text_prompt': "Text Prompt",
45
+ 'image_url': "Image URL",
 
46
  'output': "Model Output",
47
  'image_display': "Input Image",
48
  'submit': "Run Inference"
 
50
  'zh': {
51
  'title': "Pixtral模型图像描述",
52
  'text_prompt': "文本提示",
53
+ 'image_url': "图片网址",
 
54
  'output': "模型输出",
55
  'image_display': "输入图片",
56
  'submit': "运行推理"
 
58
  'jp': {
59
  'title': "Pixtralモデルによる画像説明生成",
60
  'text_prompt': "テキストプロンプト",
61
+ 'image_url': "画像URL",
 
62
  'output': "モデルの出力結果",
63
  'image_display': "入力された画像",
64
  'submit': "推論を実行"
 
67
  return labels[language]
68
 
69
  # Gradioインターフェース
70
+ def process_input(text, image_url):
71
+ result = mistral_inference(text, image_url)
72
+ return result, f'<img src="{image_url}" alt="Input Image" width="300">'
 
 
 
 
 
 
73
 
74
  def update_ui(language):
75
  labels = get_labels(language)
76
+ return labels['title'], labels['text_prompt'], labels['image_url'], labels['output'], labels['image_display'], labels['submit']
77
 
78
  with gr.Blocks() as demo:
79
  language_choice = gr.Dropdown(choices=['en', 'zh', 'jp'], label="Select Language", value='en')
 
81
  title = gr.Markdown("## Pixtral Model Image Description")
82
  with gr.Row():
83
  text_input = gr.Textbox(label="Text Prompt", placeholder="e.g. Describe the image.")
84
+ image_input = gr.Textbox(label="Image URL", placeholder="e.g. https://example.com/image.png")
85
 
 
 
 
86
  result_output = gr.Textbox(label="Model Output", lines=8, max_lines=20) # 高さ500ピクセルに相当するように調整
87
  image_output = gr.HTML(label="Input Image") # 入力画像URLを表示するための場所
88
 
89
  submit_button = gr.Button("Run Inference")
90
 
91
+ submit_button.click(process_input, inputs=[text_input, image_input], outputs=[result_output, image_output])
92
 
93
  # 言語変更時にUIラベルを更新
94
  language_choice.change(
95
  fn=update_ui,
96
  inputs=[language_choice],
97
+ outputs=[title, text_input, image_input, result_output, image_output, submit_button]
98
  )
99
 
100
+ demo.launch()