Spaces:
Running
Running
import gradio as gr | |
import websockets | |
import asyncio | |
import json | |
import base64 | |
from PIL import Image | |
import io | |
import os | |
# Credential for the hosted inference endpoint. It is sent with every
# WebSocket request, so refuse to start at all when it is missing or empty.
API_KEY = os.environ.get("API_KEY")
if API_KEY in (None, ""):
    raise ValueError("API_KEY must be set in environment variables")
# WebSocket endpoint of the hosted OmniVLM inference service; the API key is
# appended to it as a query parameter.
_WS_URL = "wss://nexa-omni.nexa4ai.com/ws/process-image/"

# Tokens that may precede the real answer (prompt-template residue) and are
# dropped when they lead the stream.
_PREAMBLE_TOKENS = frozenset({" ", " \n", "\n", "<|im_start|>", "assistant"})

# Upper bound on how many leading preamble tokens are dropped.
_MAX_PREAMBLE_SKIPS = 3


def _encode_image_as_data_url(image_path):
    """Load the image at *image_path* and return it as a base64 JPEG data URL.

    The image is converted to RGB first because Pillow's JPEG writer rejects
    modes such as RGBA or P (palette).
    """
    with Image.open(image_path) as img:
        buffer = io.BytesIO()
        img.convert("RGB").save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded}"


async def process_image_stream(image_path, prompt, max_tokens=512):
    """Stream the model's answer about an image over a WebSocket.

    This is an async generator. Each yielded value is the *cumulative*
    response so far, not just the newest token, so the caller can simply
    replace its displayed text on every update.

    Args:
        image_path: Filesystem path of the image to analyse; a falsy value
            means nothing was uploaded.
        prompt: The question to ask about the image.
        max_tokens: Generation limit forwarded unchanged to the server.

    Yields:
        The accumulated response text, or a single human-readable message
        when no image was given, the image cannot be read, the server
        reports an error, or the connection fails. Errors are never raised.
    """
    from urllib.parse import urlencode

    if not image_path:
        yield "Please upload an image first."
        return

    try:
        image_data_url = _encode_image_as_data_url(image_path)
    except Exception as e:  # keep the "never raises" contract for the UI
        yield f"Error reading image: {e}"
        return

    # URL-encode the key so characters such as '&', '#' or '+' cannot
    # corrupt the query string.
    url = f"{_WS_URL}?{urlencode({'api_key': API_KEY})}"
    request = json.dumps({
        "image": image_data_url,
        "prompt": prompt,
        "task": "instruct",  # fixed: this demo only sends instruction prompts
        "max_tokens": max_tokens,
    })

    response = ""
    skipped = 0
    in_preamble = True
    try:
        async with websockets.connect(url) as websocket:
            await websocket.send(request)
            async for message in websocket:
                try:
                    data = json.loads(message)
                except json.JSONDecodeError:
                    continue  # ignore frames that are not valid JSON
                if not isinstance(data, dict):
                    continue

                status = data.get("status")
                if status == "generating":
                    token = data.get("token") or ""
                    # Drop template tokens only while they lead the stream;
                    # once real content has appeared, spaces and newlines are
                    # part of the answer and must be kept.
                    if (
                        in_preamble
                        and skipped < _MAX_PREAMBLE_SKIPS
                        and token in _PREAMBLE_TOKENS
                    ):
                        skipped += 1
                        continue
                    in_preamble = False
                    response += token
                    yield response
                elif status == "complete":
                    break
                elif status == "error":
                    yield f"Error: {data.get('error', 'unknown error')}"
                    break
    except Exception as e:
        yield f"Error connecting to server: {e}"
# Gradio UI: the uploaded image is passed to process_image_stream as a file
# path (type="filepath"), together with the question and token limit. Each
# value the async generator yields replaces the "Response" box contents.
demo = gr.Interface(
    fn=process_image_stream,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.Textbox(
            label="Question",
            placeholder="Ask a question about the image...",
            value="Describe this image",
        ),
        # NOTE(review): the UI caps max_tokens at 200 although
        # process_image_stream defaults to 512 — confirm this is intended.
        gr.Slider(
            minimum=50,
            maximum=200,
            value=200,
            step=1,
            label="Max Tokens",
        ),
    ],
    outputs=gr.Textbox(label="Response", interactive=False),
    title="NEXA OmniVLM-968M",
    # Plain string (was an f-string with no placeholders; value unchanged).
    description="""
Model Repo: <a href="https://huggingface.co/NexaAIDev/OmniVLM-968M">NexaAIDev/OmniVLM-968M</a>
*Model updated on Nov 21, 2024\n
Upload an image and ask questions about it. The model will analyze the image and provide detailed answers to your queries.
""",
    # Paths are relative to the working directory the app is launched from.
    examples=[
        ["example_images/example_1.jpg", "What kind of cat is this?", 128],
        ["example_images/example_2.jpg", "What color is this dress? ", 128],
        ["example_images/example_3.jpg", "What is this image about?", 128],
    ],
)
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve on every interface, port 7860.
    app = demo.queue()
    app.launch(server_name="0.0.0.0", server_port=7860)