PerryCheng614's picture
change back to wss
a2bfb71
raw
history blame
3.54 kB
import gradio as gr
import websockets
import asyncio
import json
import base64
from PIL import Image
import io
import os
# Credential appended to the WebSocket URL; read once at import time so a
# misconfigured deployment fails immediately instead of on the first request.
API_KEY = os.environ.get('API_KEY')
if API_KEY is None or API_KEY == "":
    raise ValueError("API_KEY must be set in environment variables")
async def process_image_stream(image_path, prompt, max_tokens=512):
    """
    Stream a model answer about an image over a WebSocket connection.

    The image is re-encoded as an RGB JPEG, base64-encoded, and sent together
    with the prompt to the remote endpoint as a single JSON message. The
    server's replies are then consumed one message at a time.

    Args:
        image_path: Filesystem path of the image to send. A falsy value
            (``None`` or ``""``) short-circuits with a hint to the user.
        prompt: Text question or instruction about the image.
        max_tokens: Value forwarded to the server as ``"max_tokens"``.

    Yields:
        The response text accumulated so far (each yield contains the full
        text, not just the newest token), or a single human-readable error
        string if the image cannot be read, the server reports an error, or
        the connection fails.
    """
    if not image_path:
        yield "Please upload an image first."
        return

    # Tokens dropped when they occur among the first three streamed tokens
    # (presumably chat-template residue emitted before the real answer).
    leading_noise = {" ", " \n", "\n", "<|im_start|>", "assistant"}
    max_leading_tokens = 3

    # Read and convert image to base64. Kept separate from the network code
    # so that a bad image is not reported as a connection failure.
    try:
        with Image.open(image_path) as img:
            buffer = io.BytesIO()
            img.convert('RGB').save(buffer, format="JPEG")
        base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
    except (OSError, ValueError) as e:
        yield f"Error processing image: {e}"
        return

    try:
        # Connect to WebSocket
        async with websockets.connect('wss://nexa-omni.nexa4ai.com/ws/process-image/?api_key=' + API_KEY) as websocket:
            # Send image data and parameters as JSON
            await websocket.send(json.dumps({
                "image": f"data:image/jpeg;base64,{base64_image}",
                "prompt": prompt,
                "task": "instruct",  # Fixed to instruct
                "max_tokens": max_tokens
            }))

            response = ""
            tokens_seen = 0

            # Receive streaming response
            async for message in websocket:
                try:
                    data = json.loads(message)
                except json.JSONDecodeError:
                    # Ignore frames that are not valid JSON.
                    continue
                if not isinstance(data, dict):
                    continue

                status = data.get("status")
                if status == "generating":
                    token = data.get("token")
                    if not isinstance(token, str):
                        continue
                    # Count every token, not only skipped ones, so the filter
                    # applies strictly to the first few positions and never
                    # swallows a legitimate newline later in the answer.
                    position = tokens_seen
                    tokens_seen += 1
                    if position < max_leading_tokens and token in leading_noise:
                        continue
                    response += token
                    yield response
                elif status == "complete":
                    break
                elif status == "error":
                    yield f"Error: {data.get('error', 'unknown error')}"
                    break
    except Exception as e:
        # UI boundary: surface any connection/protocol failure as text
        # instead of letting the exception escape.
        yield f"Error connecting to server: {e}"
# Gradio UI: one image, one question and a token-budget slider feed the
# streaming handler; its incremental output is shown in a read-only textbox.
_image_input = gr.Image(type="filepath", label="Upload Image")
_question_input = gr.Textbox(
    label="Question",
    placeholder="Ask a question about the image...",
    value="Describe this image",
)
_max_tokens_input = gr.Slider(
    minimum=50,
    maximum=200,
    value=200,
    step=1,
    label="Max Tokens",
)
_response_output = gr.Textbox(label="Response", interactive=False)

# Text shown under the title (contains an HTML link to the model repo).
_description = """
Model Repo: <a href="https://huggingface.co/NexaAIDev/OmniVLM-968M">NexaAIDev/OmniVLM-968M</a>
*Model updated on Nov 21, 2024\n
Upload an image and ask questions about it. The model will analyze the image and provide detailed answers to your queries.
"""

# Each example row is [image path, question, max tokens], matching the inputs.
_examples = [
    ["example_images/example_1.jpg", "What kind of cat is this?", 128],
    ["example_images/example_2.jpg", "What color is this dress? ", 128],
    ["example_images/example_3.jpg", "What is this image about?", 128],
]

demo = gr.Interface(
    fn=process_image_stream,
    inputs=[_image_input, _question_input, _max_tokens_input],
    outputs=_response_output,
    title="NEXA OmniVLM-968M",
    description=_description,
    examples=_examples,
)

if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860.
    _queued = demo.queue()
    _queued.launch(server_name="0.0.0.0", server_port=7860)