# NOTE(review): the three lines that preceded this file ("Spaces:", "Build
# error", "Build error") were CI/paste artifacts, not Python source; they
# made the file syntactically invalid and were replaced by this comment.
import argparse | |
import queue | |
import sys | |
import uuid | |
from functools import partial | |
import numpy as np | |
import tritonclient.grpc as grpcclient | |
from tritonclient.utils import InferenceServerException | |
import gradio as gr | |
from functools import wraps | |
#### | |
from PIL import Image | |
import base64 | |
import io | |
##### | |
from http.server import HTTPServer, SimpleHTTPRequestHandler | |
import socket | |
#### | |
import os | |
import uuid | |
#### | |
class UserData:
    """Per-request container for results arriving on the Triton gRPC stream.

    The streaming callback deposits either results or errors into the
    thread-safe queue; the caller drains it on the main thread.
    """

    def __init__(self):
        # Filled by `callback`, consumed by `make_a_try`.
        self._completed_requests = queue.Queue()
def callback(user_data, result, error):
    """Triton streaming callback: enqueue the error when present, else the result."""
    item = error if error else result
    user_data._completed_requests.put(item)
def make_a_try(img_url, text):
    """Run one image+text query against the Triton 'ensemble_mllm' model.

    Sends IMAGE_URL and TEXT as BYTES tensors over a gRPC stream as a
    single-request sequence, then drains the streamed OUTPUT chunks and
    joins them into one response string.

    Args:
        img_url: HTTP URL of the image, fetched server-side by the ensemble.
        text: the user prompt.

    Returns:
        The concatenated decoded response, or "" on any inference error,
        unexpected sequence id, or stream error.
    """
    model_name = 'ensemble_mllm'
    user_data = UserData()
    sequence_id = 100
    result_list = []

    def _bytes_tensor(name, value):
        # Encode a str as a [1, 1] BYTES tensor, the shape the ensemble expects.
        arr = np.array(value.encode("utf-8"), dtype=bytes).reshape([1, -1])
        tensor = grpcclient.InferInput(name, arr.shape, "BYTES")
        tensor.set_data_from_numpy(arr)
        return tensor

    with grpcclient.InferenceServerClient(
        url="10.199.14.151:8001", verbose=False
    ) as triton_client:
        try:
            # Establish the stream; results are delivered to `callback`.
            triton_client.start_stream(
                callback=partial(callback, user_data),
                stream_timeout=None,
            )
            inputs = [
                _bytes_tensor('IMAGE_URL', img_url),
                _bytes_tensor('TEXT', text),
            ]
            outputs = [grpcclient.InferRequestedOutput("OUTPUT")]
            # Single-request sequence: start and end on the same request.
            triton_client.async_stream_infer(
                model_name=model_name,
                inputs=inputs,
                outputs=outputs,
                request_id="{}".format(sequence_id),
                sequence_id=sequence_id,
                sequence_start=True,
                sequence_end=True,
            )
        except InferenceServerException as error:
            print(error)
            return ""

        # Drain streamed chunks; a 5-second silence or an empty chunk ends it.
        while True:
            try:
                data_item = user_data._completed_requests.get(timeout=5)
            except queue.Empty:
                break
            # The callback enqueues exceptions as-is; check before using.
            if isinstance(data_item, InferenceServerException):
                print('InferenceServerException: ', data_item)
                return ""
            this_id = data_item.get_response().id.split("_")[0]
            if int(this_id) != sequence_id:
                print("unexpected sequence id returned by the server: {}".format(this_id))
                return ""
            result = data_item.as_numpy("OUTPUT")
            # An empty chunk marks end-of-generation.
            if len(result[0][0]) == 0:
                break
            result_list.append(result)

    return ''.join([item[0][0].decode('utf-8') for item in result_list])
def greet(image, text):
    """Gradio handler: save the uploaded image, query the model, return its answer.

    The PIL image is written as a JPEG with a unique name into the Gradio
    static directory so the inference server can fetch it by URL.

    Args:
        image: PIL image from the gr.Image component (may be None).
        text: prompt from the gr.Textbox component.

    Returns:
        The model's response string, or "" when the image cannot be
        encoded as JPEG (e.g. nothing was uploaded).
    """
    static_path = "/workdir/yanghandi/gradio_demo/static"
    # Serialize to an in-memory JPEG first; only touch disk on success.
    img_byte_arr = io.BytesIO()
    try:
        image.save(img_byte_arr, format='JPEG')
    except Exception:
        # Best-effort by design: an unsaveable/missing image yields an
        # empty answer instead of crashing the UI.
        return ""
    # Unique filename so concurrent users never overwrite each other.
    filename = f"image_{uuid.uuid4()}.jpg"
    filepath = os.path.join(static_path, filename)
    with open(filepath, 'wb') as f:
        f.write(img_byte_arr.getvalue())
    # URL under Gradio's /file=static/ route, fetched by the Triton ensemble.
    img_url = "http://10.99.5.48:8080/file=static/" + filename
    return make_a_try(img_url, text)
# def greet_example(image, text): | |
# ###save img | |
# # filename = image | |
# # static_path = "/workdir/yanghandi/gradio_demo/static" | |
# img_url = "http://10.99.5.48:8080/file=static/0000.jpeg" | |
# # img_url = PIL_to_URL(img_url) | |
# # img_url = "http://10.99.5.48:8080/file=static/0000.jpeg" | |
# result = make_a_try(img_url,text) | |
# # print(result) | |
# return result | |
def clear_output():
    """Reset the output textbox to an empty string (clear-button handler)."""
    return ""
def get_example():
    """Return the example rows (image path, prompt) shown below the demo."""
    example_image = "/workdir/yanghandi/gradio_demo/static/0001.jpg"
    example_prompt = "图中的人物是谁"
    return [[example_image, example_prompt]]
if __name__ == "__main__":
    # Connection/model parameters. NOTE(review): this dict is built but never
    # read below — the actual endpoint lives in make_a_try; kept for parity
    # with the original script.
    param_info = {
        'appkey': "10.199.14.151:8001",
        'remote_appkey': "10.199.14.151:8001",
        'model_name': 'ensemble_mllm',
        'model_version': "1",
        'time_out': 60000,
        'server_targets': [],
        'outputs': 'response',
    }

    # Expose ./static so saved uploads are reachable via /file=static/...
    gr.set_static_paths(paths=["static/"])

    with gr.Blocks(title='demo') as demo:
        gr.Markdown("# 自研模型测试demo")
        gr.Markdown("尝试使用该demo,上传图片并开始讨论它,或者尝试下面的例子")
        with gr.Row():
            with gr.Column():
                imagebox = gr.Image(type="pil")
                promptbox = gr.Textbox(label="prompt")
            with gr.Column():
                output = gr.Textbox(label="output")
        # NOTE(review): original indentation was lost; buttons assumed to sit
        # in their own row below the input/output columns — confirm layout.
        with gr.Row():
            submit = gr.Button("submit")
            clear = gr.Button("clear")
        submit.click(fn=greet, inputs=[imagebox, promptbox], outputs=[output])
        clear.click(fn=clear_output, inputs=[], outputs=[output])
        gr.Markdown("# example")
        gr.Examples(
            examples=get_example(),
            fn=greet,
            inputs=[imagebox, promptbox],
            outputs=[output],
            cache_examples=True,
        )

    demo.launch(server_name="0.0.0.0", server_port=8080, debug=True, share=True)