PerryCheng614 committed on
Commit
a2bfb71
1 Parent(s): 776bb63

change back to wss

Browse files
Files changed (1) hide show
  1. app.py +44 -63
app.py CHANGED
@@ -1,79 +1,67 @@
1
  import gradio as gr
2
- import requests
 
3
  import json
 
 
 
4
  import os
5
- import time
6
 
7
- API_KEY = os.getenv("API_KEY")
8
  if not API_KEY:
9
- raise ValueError("API_KEY environment variable must be set")
10
 
11
- def process_image_stream(image_path, prompt, max_tokens=512):
12
  """
13
- Process image with streaming response via HTTP
14
  """
15
  if not image_path:
16
  yield "Please upload an image first."
17
  return
18
 
19
  try:
20
- # Read and prepare image file
21
- with open(image_path, 'rb') as img_file:
22
- files = {
23
- 'image': ('image.jpg', img_file, 'image/jpeg')
24
- }
25
- data = {
26
- 'prompt': prompt,
27
- 'task': 'instruct',
28
- 'max_tokens': max_tokens
29
- }
30
- headers = {
31
- 'X-API-Key': API_KEY
32
- }
33
 
34
- # Make streaming request
35
- response = requests.post(
36
- 'https://nexa-omni.nexa4ai.com/process-image/',
37
- files=files,
38
- data=data,
39
- headers=headers,
40
- stream=True
41
- )
 
42
 
43
- if response.status_code != 200:
44
- yield f"Error: Server returned status code {response.status_code}"
45
- return
46
-
47
  # Initialize response and token counter
48
- response_text = ""
49
  token_count = 0
50
 
51
- # Process the streaming response
52
- for line in response.iter_lines():
53
- if line:
54
- line = line.decode('utf-8')
55
- if line.startswith('data: '):
56
- try:
57
- data = json.loads(line[6:]) # Skip 'data: ' prefix
58
- if data["status"] == "generating":
59
- # Skip first three tokens if they match specific patterns
60
- if token_count < 3 and data["token"] in [" ", " \n", "\n", "<|im_start|>", "assistant"]:
61
- token_count += 1
62
- continue
63
- response_text += data["token"]
64
- gr.update(value=response_text)
65
- yield response_text
66
- time.sleep(0.01)
67
- elif data["status"] == "complete":
68
- break
69
- elif data["status"] == "error":
70
- yield f"Error: {data['error']}"
71
- break
72
- except json.JSONDecodeError:
73
  continue
 
 
 
 
 
 
 
 
 
74
 
75
  except Exception as e:
76
- yield f"Error processing request: {str(e)}"
77
 
78
  # Create Gradio interface
79
  demo = gr.Interface(
@@ -97,7 +85,6 @@ demo = gr.Interface(
97
  title="NEXA OmniVLM-968M",
98
  description=f"""
99
  Model Repo: <a href="https://huggingface.co/NexaAIDev/OmniVLM-968M">NexaAIDev/OmniVLM-968M</a>
100
-
101
  *Model updated on Nov 21, 2024\n
102
  Upload an image and ask questions about it. The model will analyze the image and provide detailed answers to your queries.
103
  """,
@@ -109,10 +96,4 @@ demo = gr.Interface(
109
  )
110
 
111
  if __name__ == "__main__":
112
- # Configure the queue for better streaming performance
113
- demo.queue(
114
- max_size=20,
115
- ).launch(
116
- server_name="0.0.0.0",
117
- server_port=7860,
118
- )
 
1
  import gradio as gr
2
+ import websockets
3
+ import asyncio
4
  import json
5
+ import base64
6
+ from PIL import Image
7
+ import io
8
  import os
 
9
 
10
+ API_KEY = os.getenv('API_KEY')
11
  if not API_KEY:
12
+ raise ValueError("API_KEY must be set in environment variables")
13
 
14
async def process_image_stream(image_path, prompt, max_tokens=512):
    """
    Process image with streaming response via WebSocket.

    Async generator. Each yielded value is either the full response text
    accumulated so far, or a single human-readable notice/error string.

    Args:
        image_path: Path of the image to send. A falsy value yields a
            request to upload an image and stops.
        prompt: Text sent alongside the image in the request payload.
        max_tokens: Value forwarded in the payload's "max_tokens" field.

    Yields:
        str: The growing response text, or an error message.
    """
    if not image_path:
        yield "Please upload an image first."
        return

    # Local import so the block does not depend on a file-level import.
    from urllib.parse import urlencode

    # Tokens dropped when they occur among the first three streamed tokens.
    leading_noise = (" ", " \n", "\n", "<|im_start|>", "assistant")

    # Read the image and re-encode it as a base64 JPEG. Handled separately
    # so a bad file is not reported as a connection problem.
    try:
        with Image.open(image_path) as img:
            buffer = io.BytesIO()
            img.convert('RGB').save(buffer, format="JPEG")
        base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
    except Exception as e:
        yield f"Error reading image: {e}"
        return

    # Percent-encode the key so special characters cannot corrupt the URL.
    url = 'wss://nexa-omni.nexa4ai.com/ws/process-image/?' + urlencode({"api_key": API_KEY})

    try:
        async with websockets.connect(url) as websocket:
            # Send image data and parameters as JSON
            await websocket.send(json.dumps({
                "image": f"data:image/jpeg;base64,{base64_image}",
                "prompt": prompt,
                "task": "instruct",  # Fixed to instruct
                "max_tokens": max_tokens,
            }))

            response = ""
            token_count = 0

            # Receive streaming response
            async for message in websocket:
                try:
                    data = json.loads(message)
                except ValueError:
                    # Covers JSONDecodeError and undecodable bytes frames.
                    continue
                if not isinstance(data, dict):
                    continue

                status = data.get("status")
                if status == "generating":
                    token = data.get("token")
                    if not isinstance(token, str):
                        continue
                    # Count every token, not only skipped ones, so the filter
                    # really applies to the first three tokens and cannot
                    # swallow whitespace/newlines later in the response.
                    token_count += 1
                    if token_count <= 3 and token in leading_noise:
                        continue
                    response += token
                    yield response
                elif status == "complete":
                    break
                elif status == "error":
                    yield f"Error: {data.get('error', 'unknown error')}"
                    break

    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        yield f"Error connecting to server: {str(e)}"
65
 
66
  # Create Gradio interface
67
  demo = gr.Interface(
 
85
  title="NEXA OmniVLM-968M",
86
  description=f"""
87
  Model Repo: <a href="https://huggingface.co/NexaAIDev/OmniVLM-968M">NexaAIDev/OmniVLM-968M</a>
 
88
  *Model updated on Nov 21, 2024\n
89
  Upload an image and ask questions about it. The model will analyze the image and provide detailed answers to your queries.
90
  """,
 
96
  )
97
 
98
  if __name__ == "__main__":
99
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)