File size: 3,288 Bytes
76578bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e0c709
76578bc
6e0c709
76578bc
 
 
 
 
 
6e0c709
 
 
 
76578bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3388a44
76578bc
 
 
6e0c709
 
 
76578bc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import websockets
import asyncio
import json
import base64
from PIL import Image
import io

async def process_image_stream(image_path, prompt, max_tokens=512):
    """
    Process image with streaming response via WebSocket
    """
    if not image_path:
        yield "Please upload an image first."
        return
    
    try:
        # Read and convert image to base64
        with Image.open(image_path) as img:
            img = img.convert('RGB')
            buffer = io.BytesIO()
            img.save(buffer, format="JPEG")
            base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
            
        # Connect to WebSocket
        async with websockets.connect('wss://nexa-omni.nexa4ai.com/ws/process-image/') as websocket:
            # Send image data and parameters as JSON
            await websocket.send(json.dumps({
                "image": f"data:image/jpeg;base64,{base64_image}",
                "prompt": prompt,
                "task": "instruct",  # Fixed to instruct
                "max_tokens": max_tokens
            }))
            
            # Initialize response and token counter
            response = ""
            token_count = 0
            
            # Receive streaming response
            async for message in websocket:
                try:
                    data = json.loads(message)
                    if data["status"] == "generating":
                        # Skip first three tokens if they match specific patterns
                        if token_count < 3 and data["token"] in [" ", " \n", "\n", "<|im_start|>", "assistant"]:
                            token_count += 1
                            continue
                        response += data["token"]
                        yield response
                    elif data["status"] == "complete":
                        break
                    elif data["status"] == "error":
                        yield f"Error: {data['error']}"
                        break
                except json.JSONDecodeError:
                    continue
                
    except Exception as e:
        yield f"Error connecting to server: {str(e)}"

# Build the Gradio UI: image + question + token budget in, streamed text out.
# `process_image_stream` is an async generator, so the output box updates
# incrementally as tokens arrive.
_image_input = gr.Image(type="filepath", label="Upload Image")
_question_input = gr.Textbox(
    label="Question",
    placeholder="Ask a question about the image...",
    value="Describe this image",
)
_max_tokens_input = gr.Slider(
    minimum=50,
    maximum=200,
    value=200,
    step=1,
    label="Max Tokens",
)
_response_output = gr.Textbox(label="Response", interactive=False)

# Same text as the original triple-quoted literal, including its indentation.
_description = (
    "\n"
    "    *Model updated on Nov 21, 2024\n"
    "\n"
    "    Upload an image and ask questions about it. The model will analyze the image and provide detailed answers to your queries.\n"
    "    "
)

# Each example row is (image path, question, max tokens).
_examples = [
    ["example_images/example_1.jpg", "What kind of cat is this?", 128],
    ["example_images/example_2.jpg", "What color is this dress? ", 128],
    ["example_images/example_3.jpg", "What is this image about?", 128],
]

demo = gr.Interface(
    fn=process_image_stream,
    inputs=[_image_input, _question_input, _max_tokens_input],
    outputs=_response_output,
    title="Nexa Omni Vision",
    description=_description,
    examples=_examples,
)

if __name__ == "__main__":
    # Enable Gradio's request queue, then serve on all interfaces at port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)