taka-yayoi committed on
Commit
ac8da7f
·
verified ·
1 Parent(s): 089a0d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -57
app.py CHANGED
@@ -1,63 +1,105 @@
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for *message* from the Inference API client.

    Args:
        message: The user's latest input text.
        history: Prior turns as (user, assistant) string pairs; empty
            strings are skipped.
        system_message: Text for the leading "system" role message.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated assistant response after each streamed chunk.
    """
    messages = [{"role": "system", "content": system_message}]

    # Replay the conversation so the model sees full context.
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})

    messages.append({"role": "user", "content": message})

    response = ""

    # FIX: use a distinct loop variable -- the original iterated with
    # `for message in ...`, shadowing the `message` parameter.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        # FIX: the final stream chunk may carry no content (None);
        # guard it so `response += token` cannot raise TypeError.
        if token:
            response += token
        yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Generation controls exposed as additional inputs, built separately so
# each widget is individually named and easy to tweak.
system_box = gr.Textbox(value="You are a friendly Chatbot.", label="System message")
max_tokens_slider = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

# Chat UI wired to the streaming `respond` generator.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        system_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ],
)


if __name__ == "__main__":
    demo.launch()
 
1
+ import itertools
2
  import gradio as gr
3
+ import requests
4
+ import os
5
+ from gradio.themes.utils import sizes
6
+ import json
7
+ import pandas as pd
8
+
9
+ import base64
10
+ import io
11
+ from PIL import Image
12
+ import numpy as np
13
+
14
def respond(message, history):
    """Chat handler: send *message* to a model-serving endpoint and return
    the generated image embedded as base64 markdown.

    Args:
        message: The user's prompt text.
        history: Prior chat turns (unused; required by gr.ChatInterface).

    Returns:
        A markdown string with an inline base64 PNG on success, or a
        short error string on bad input / missing config / request failure.
    """
    if len(message.strip()) == 0:
        return "質問を入力してください"

    local_token = os.getenv('API_TOKEN')
    local_endpoint = os.getenv('API_ENDPOINT')

    if local_token is None or local_endpoint is None:
        return "ERROR missing env variables"

    # Add your API token to the headers
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {local_token}'
    }

    # Build the request payload in the MLflow "dataframe_split" serving
    # format expected by the endpoint.
    prompt = pd.DataFrame(
        {"prompt": [message], "num_inference_steps": 25}
    )

    print(prompt)
    ds_dict = {"dataframe_split": prompt.to_dict(orient="split")}
    data_json = json.dumps(ds_dict, allow_nan=True)

    try:
        # Query the model serving endpoint. FIX: add a timeout so a stuck
        # endpoint cannot hang the UI thread indefinitely.
        response = requests.request(
            method="POST",
            headers=headers,
            url=local_endpoint,
            data=data_json,
            timeout=300,
        )
        response_data = response.json()

        # Convert the predictions to a uint8 numpy array, then to an image.
        # NOTE(review): assumes predictions are HxWx3 RGB pixel values --
        # confirm against the serving model's output schema.
        im_array = np.array(response_data["predictions"], dtype=np.uint8)
        im = Image.fromarray(im_array, 'RGB')

        # Serialize the image to PNG and base64-encode it.
        raw_bytes = io.BytesIO()
        im.save(raw_bytes, "PNG")
        raw_bytes.seek(0)  # rewind to the start of the buffer
        image_encoded = base64.b64encode(raw_bytes.read()).decode('ascii')

        # Embed as an inline data-URI image in markdown.
        return f"![](data:image/png;base64,{image_encoded})"
    except Exception as error:
        # FIX: the original assigned this message to a dead variable and
        # returned the empty markdown string, silently hiding every
        # failure. Surface the error to the chat instead.
        return f"ERROR status_code: {type(error).__name__}"
80
+
81
+
82
# Compact Soft theme for the chat UI.
theme = gr.themes.Soft(
    text_size=sizes.text_sm,
    radius_size=sizes.radius_sm,
    spacing_size=sizes.spacing_sm,
)


# Chat UI wired to `respond`; the examples are common Databricks questions.
demo = gr.ChatInterface(
    respond,
    chatbot=gr.Chatbot(
        show_label=False,
        container=False,
        show_copy_button=True,
        bubble_full_width=True,
    ),
    textbox=gr.Textbox(
        placeholder="質問を入力してください",
        container=False,
        scale=7,
    ),
    title="Databricks QAチャットボット",
    description="TBD",
    examples=[
        ["Databricksクラスターとは?"],
        ["Unity Catalogの有効化方法"],
        ["リネージの保持期間"],
    ],
    cache_examples=False,
    theme=theme,
    retry_btn=None,
    undo_btn=None,
    clear_btn="Clear",
)


if __name__ == "__main__":
    demo.launch()