File size: 3,536 Bytes
76578bc
a2bfb71
 
76578bc
a2bfb71
 
 
70a6a62
76578bc
a2bfb71
70a6a62
a2bfb71
70a6a62
a2bfb71
76578bc
a2bfb71
76578bc
 
 
 
 
 
a2bfb71
 
 
 
 
 
76578bc
a2bfb71
 
 
 
 
 
 
 
 
76578bc
6e0c709
a2bfb71
6e0c709
76578bc
a2bfb71
 
 
 
 
 
 
 
6e0c709
a2bfb71
 
 
 
 
 
 
 
 
76578bc
 
a2bfb71
76578bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6978e20
f651569
6978e20
3388a44
76578bc
 
 
6e0c709
 
 
76578bc
 
 
 
a2bfb71
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import websockets
import asyncio
import json
import base64
from PIL import Image
import io
import os

# Credential for the Nexa WebSocket endpoint. Read once at import time and
# refuse to start when it is absent or empty.
API_KEY = os.getenv('API_KEY')
if API_KEY is None or API_KEY == "":
    raise ValueError("API_KEY must be set in environment variables")

async def process_image_stream(image_path, prompt, max_tokens=512):
    """
    Process image with streaming response via WebSocket.

    The uploaded image is re-encoded as an RGB JPEG and sent, together with
    the prompt, to the Nexa OmniVLM WebSocket endpoint. Generated tokens are
    accumulated and yielded progressively so the Gradio output updates live.

    Args:
        image_path: Path to the uploaded image (Gradio ``type="filepath"``),
            or a falsy value when nothing was uploaded.
        prompt: Question or instruction about the image.
        max_tokens: Generation limit forwarded to the server as ``max_tokens``.

    Yields:
        The full response text accumulated so far after each kept token, or a
        single human-readable error message. Errors are reported by yielding,
        never by raising.
    """
    from urllib.parse import quote

    if not image_path:
        yield "Please upload an image first."
        return

    # Prompt-template / whitespace tokens that may precede the real answer.
    # Only a *leading* run of at most three of these is dropped; once real
    # content has started, every token (including spaces/newlines) is kept.
    skippable_leading = {" ", " \n", "\n", "<|im_start|>", "assistant"}
    max_leading_skips = 3

    # Read and convert image to base64. Kept separate from the network code so
    # an unreadable file is not misreported as a connection problem.
    try:
        with Image.open(image_path) as img:
            buffer = io.BytesIO()
            img.convert('RGB').save(buffer, format="JPEG")
        base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
    except Exception as e:
        yield f"Error reading image: {str(e)}"
        return

    try:
        # URL-encode the key so characters such as '&', '#' or '+' cannot
        # corrupt the query string.
        url = ('wss://nexa-omni.nexa4ai.com/ws/process-image/?api_key='
               + quote(API_KEY, safe=''))
        async with websockets.connect(url) as websocket:
            # Send image data and parameters as JSON
            await websocket.send(json.dumps({
                "image": f"data:image/jpeg;base64,{base64_image}",
                "prompt": prompt,
                "task": "instruct",  # Fixed to instruct
                "max_tokens": max_tokens
            }))

            response = ""
            skipped = 0
            in_preamble = True  # True until the first non-skipped token

            # Receive streaming response
            async for message in websocket:
                try:
                    data = json.loads(message)
                except json.JSONDecodeError:
                    continue  # ignore frames that are not valid JSON
                if not isinstance(data, dict):
                    continue  # ignore JSON frames that are not objects

                status = data.get("status")
                if status == "generating":
                    token = data.get("token", "")
                    if (in_preamble and skipped < max_leading_skips
                            and token in skippable_leading):
                        skipped += 1
                        continue
                    in_preamble = False
                    response += token
                    yield response
                elif status == "complete":
                    break
                elif status == "error":
                    yield f"Error: {data.get('error', 'unknown error')}"
                    break
    except Exception as e:
        yield f"Error connecting to server: {str(e)}"

# Create Gradio interface: an image, a question and a token limit go in; the
# text yielded progressively by process_image_stream comes out.
_DESCRIPTION = """
    Model Repo: <a href="https://huggingface.co/NexaAIDev/OmniVLM-968M">NexaAIDev/OmniVLM-968M</a>
    *Model updated on Nov 21, 2024\n
    Upload an image and ask questions about it. The model will analyze the image and provide detailed answers to your queries.
    """

# Each row: [image path, question, max tokens].
_EXAMPLES = [
    ["example_images/example_1.jpg", "What kind of cat is this?", 128],
    ["example_images/example_2.jpg", "What color is this dress? ", 128],
    ["example_images/example_3.jpg", "What is this image about?", 128],
]

_image_input = gr.Image(type="filepath", label="Upload Image")
_question_input = gr.Textbox(
    label="Question",
    placeholder="Ask a question about the image...",
    value="Describe this image",
)
_max_tokens_input = gr.Slider(
    label="Max Tokens",
    minimum=50,
    maximum=200,
    step=1,
    value=200,
)
_response_output = gr.Textbox(label="Response", interactive=False)

demo = gr.Interface(
    fn=process_image_stream,
    inputs=[_image_input, _question_input, _max_tokens_input],
    outputs=_response_output,
    title="NEXA OmniVLM-968M",
    description=_DESCRIPTION,
    examples=_EXAMPLES,
)

if __name__ == "__main__":
    # Enable Gradio's request queue, then serve on all interfaces, port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_port=7860, server_name="0.0.0.0")