Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# | |
# This source code is licensed under the Chameleon License found in the | |
# LICENSE file in the root directory of this source tree. | |
import base64 | |
import io | |
import socket | |
import subprocess | |
import time | |
from functools import partial | |
import fastapi | |
import PIL | |
import pydantic | |
import redis.asyncio as async_redis | |
import uvicorn | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, WebSocketException | |
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK | |
from chameleon.viewer.backend.data_types import ( | |
Content, | |
ContentType, | |
NoOptionsForComplete, | |
NoOptionsForFull, | |
NoOptionsForPartial, | |
NoOptionsForQueueStatus, | |
WSMessageType, | |
WSMultimodalMessage, | |
) | |
from chameleon.viewer.backend.models.abstract_model import ( | |
AbstractMultimodalGenerator, | |
StreamingImage, | |
) | |
from chameleon.viewer.backend.models.chameleon_distributed import AsyncRedisCounter | |
from chameleon.viewer.backend.utils import get_logger | |
logger = get_logger(__name__)


def nvidia_smi() -> str:
    """Run the ``nvidia-smi`` executable and return its standard output as text.

    Raises:
        subprocess.CalledProcessError: if ``nvidia-smi`` exits with a non-zero
            status.
    """
    completed = subprocess.run(
        ["nvidia-smi"],
        stdout=subprocess.PIPE,
        check=True,
        text=True,
    )
    return completed.stdout
async def await_generate_message(websocket: WebSocket) -> WSMultimodalMessage:
    """Wait for the next well-formed message on ``websocket`` and return it.

    Every incoming JSON payload is validated against ``WSMultimodalMessage``.
    A payload that fails validation is logged and skipped, and the loop keeps
    listening, so this coroutine only returns once a valid message arrives.

    Args:
        websocket: The accepted websocket connection to read from.

    Returns:
        The first payload that parses as a ``WSMultimodalMessage``.

    Exceptions raised by ``websocket.receive_json()`` are not caught here and
    propagate to the caller.
    """
    while True:
        rec_message = await websocket.receive_json()
        try:
            return WSMultimodalMessage.parse_obj(rec_message)
        except pydantic.ValidationError:
            # Log the rejected payload itself (truncated, it may embed a large
            # base64 image) through a real %s placeholder; passing an extra
            # argument without a placeholder makes the logging call fail to
            # format.
            logger.info("Got invalid message: %s", str(rec_message)[:300])
async def async_acquire_lock(
    *,
    websocket: WebSocket,
    counter: AsyncRedisCounter,
    lock: async_redis.lock.Lock,
    interval=0.1,
    status_interval=1,
    hostname: str | None = None,
):
    """Block until ``lock`` is acquired, reporting queue status while waiting.

    The caller is counted in ``counter`` for as long as it is waiting. While
    the lock is unavailable, a ``QUEUE_STATUS`` message carrying the current
    value of ``counter.count()`` is sent over ``websocket`` each time more
    than ``status_interval`` seconds have passed since the previous report.

    Args:
        websocket: Connection that receives the queue status messages.
        counter: Counter of pending requests; incremented on entry and
            decremented on exit.
        lock: The lock to acquire.
        interval: Seconds each individual acquire attempt may block.
        status_interval: Minimum number of seconds between status messages.
        hostname: Value placed in the ``debug_info`` of status messages.

    Any exception raised while waiting (for example by ``send_json`` when the
    client has gone away) propagates to the caller after the counter has been
    decremented.
    """
    # A monotonic clock keeps the reporting cadence independent of wall-clock
    # adjustments.
    last_report = time.monotonic()
    await counter.add(1)
    try:
        while not await lock.acquire(blocking_timeout=interval):
            if time.monotonic() - last_report <= status_interval:
                continue
            n_requests = await counter.count()
            message = WSMultimodalMessage(
                message_type=WSMessageType.QUEUE_STATUS,
                content=[
                    Content(
                        content_type=ContentType.TEXT,
                        content=f"n_requests={n_requests}",
                    )
                ],
                options=NoOptionsForQueueStatus(),
                debug_info={"hostname": hostname},
            ).dict()
            await websocket.send_json(message)
            last_report = time.monotonic()
    finally:
        # Always undo the increment: otherwise a client that disconnects while
        # queued would leave the pending count permanently too high.
        await counter.sub(1)
COORDINATOR = "coordinator" | |
def web_app( | |
generator: AbstractMultimodalGenerator, | |
debug: bool = True, | |
redis_port: int | None = None, | |
) -> FastAPI: | |
app = FastAPI(debug=debug) | |
if redis_port is None: | |
redis_client = None | |
redis_lock = None | |
queue_counter = None | |
else: | |
redis_client = async_redis.Redis.from_url(f"redis://redis:{redis_port}") | |
redis_lock = async_redis.lock.Lock(redis_client, COORDINATOR) | |
queue_counter = AsyncRedisCounter(redis_client, "count_pending") | |
hostname = socket.gethostname() | |
def alive() -> dict: | |
return { | |
"status": "alive", | |
"hostname": hostname, | |
"nvidia-smi": nvidia_smi(), | |
} | |
async def websocket_chameleon_v2(*, websocket: WebSocket, client_id: str): | |
logger.info("Requested client_id: %s", client_id) | |
await websocket.accept() | |
logger.info("Client opened %s with generator id %s", client_id, id(generator)) | |
try: | |
while True: | |
generate_message = await await_generate_message(websocket) | |
logger.info("Got generate message: %s", str(generate_message)[:300]) | |
parsed_prompt = [] | |
for c in generate_message.content: | |
match c.content_type: | |
case ContentType.TEXT: | |
parsed_prompt.append(c.content) | |
case ContentType.IMAGE: | |
image_parts = c.content.split(",", 1) | |
if len(image_parts) < 2: | |
logger.error( | |
"Encountered invalid image: %s", image_parts | |
) | |
raise WebSocketException( | |
code=fastapi.status.WS_1008_POLICY_VIOLATION, | |
reason=f"Invalid image: {image_parts}", | |
) | |
image_data = image_parts[1] | |
base64_image = base64.b64decode(image_data) | |
image_file = io.BytesIO(base64_image) | |
parsed_prompt.append(PIL.Image.open(image_file)) | |
case _: | |
raise ValueError("Unknown content type") | |
logger.info("Prompt: %s", parsed_prompt) | |
partial_outputs = [] | |
final_contents: list[Content] = [] | |
match generate_message.message_type: | |
case WSMessageType.GENERATE_TEXT: | |
output_generator = generator.generate_text_streaming | |
case WSMessageType.GENERATE_IMAGE: | |
output_generator = generator.generate_image_streaming | |
case WSMessageType.GENERATE_MULTIMODAL: | |
output_generator = generator.generate_multimodal_streaming | |
case _: | |
raise WebSocketException( | |
code=fastapi.status.WS_1008_POLICY_VIOLATION, | |
reason="Unknown message type", | |
) | |
logger.info( | |
"Acquiring lock for client %s generation with options: %s", | |
client_id, | |
generate_message.options, | |
) | |
option_args = generate_message.options.dict() | |
debug_info = {"hostname": hostname} | |
del option_args["message_type"] | |
output_generator = partial( | |
output_generator, | |
**option_args, | |
debug=debug_info, | |
) | |
if redis_lock is not None: | |
await async_acquire_lock( | |
websocket=websocket, | |
lock=redis_lock, | |
hostname=hostname, | |
counter=queue_counter, | |
) | |
await redis_client.set("has_lock", client_id) | |
logger.info( | |
"Starting locked generation for client %s with options: %s", | |
client_id, | |
generate_message.options, | |
) | |
try: | |
async for output_token in output_generator(parsed_prompt): | |
if isinstance(output_token, str): | |
content_type = ContentType.TEXT | |
content = output_token | |
message_type = WSMessageType.PARTIAL_OUTPUT | |
options = NoOptionsForPartial() | |
partial_outputs.extend(output_token) | |
elif isinstance(output_token, StreamingImage): | |
content_type = ContentType.IMAGE | |
image = output_token.image | |
img_io = io.BytesIO() | |
image.save(img_io, format="png") | |
content = ( | |
"data:image/png;base64," | |
+ base64.b64encode(img_io.getvalue()).decode() | |
) | |
if output_token.final: | |
message_type = WSMessageType.FULL_OUTPUT | |
options = NoOptionsForFull() | |
else: | |
message_type = WSMessageType.PARTIAL_OUTPUT | |
options = NoOptionsForPartial() | |
if output_token.final: | |
partial_outputs.append(output_token.image) | |
else: | |
raise ValueError(f"Invalid output_token: {output_token}") | |
message_content = Content( | |
content_type=content_type, content=content | |
) | |
match content_type: | |
case ContentType.TEXT: | |
final_contents.append(message_content) | |
case ContentType.IMAGE: | |
if message_type == WSMessageType.FULL_OUTPUT: | |
final_contents.append(message_content) | |
case _: | |
pass | |
message = WSMultimodalMessage( | |
message_type=message_type, | |
content=[message_content], | |
options=options, | |
debug_info=debug_info, | |
).dict() | |
await websocket.send_json(message) | |
finally: | |
if redis_lock is not None: | |
logger.info( | |
"Attempting release of lock for client %s generation with options: %s", | |
client_id, | |
generate_message.options, | |
) | |
owned = await redis_lock.owned() | |
if owned: | |
await redis_client.set("has_lock", "") | |
try: | |
await redis_lock.release() | |
except async_redis.lock.LockError: | |
pass | |
logger.info( | |
"Released lock for client %s generation with options: %s", | |
client_id, | |
generate_message.options, | |
) | |
await websocket.send_json( | |
WSMultimodalMessage( | |
message_type=WSMessageType.COMPLETE, | |
content=final_contents, | |
options=NoOptionsForComplete(), | |
debug_info=debug_info, | |
).dict() | |
) | |
except WebSocketDisconnect: | |
logger.info("Client disconnected %s", client_id) | |
except ConnectionClosedError: | |
logger.info("Client forced a close %s", client_id) | |
except ConnectionClosedOK: | |
logger.info("Connection closed ok %s", client_id) | |
finally: | |
if redis_lock is not None: | |
logger.info("Checking for client holding lock: %s", client_id) | |
owned = await redis_lock.owned() | |
if owned: | |
try: | |
logger.info("Attempted to release owned lock: %s", client_id) | |
await redis_lock.release() | |
except async_redis.lock.LockError: | |
pass | |
await redis_client.set("has_lock", "") | |
return app | |
def serve(
    model: AbstractMultimodalGenerator,
    host: str,
    port: int,
    debug: bool = True,
    redis_port: int | None = None,
) -> None:
    """Create the web application for ``model`` and run it under uvicorn.

    Args:
        model: Generator handed to ``web_app``.
        host: Interface passed to ``uvicorn.run``.
        port: Port passed to ``uvicorn.run``.
        debug: Forwarded to ``web_app``.
        redis_port: Forwarded to ``web_app``.
    """
    # TODO: convert this to a subprocess call to enable more
    # uvicorn features like multiple workers
    uvicorn.run(
        web_app(model, debug=debug, redis_port=redis_port),
        host=host,
        port=port,
    )