from typing import List, Optional, Union, Literal

from fastapi import FastAPI, Body
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image as PILImage
import torch
import base64
import gc
import io
from starlette.responses import FileResponse

app = FastAPI(docs_url="/docs", redoc_url="/redoc", openapi_url="/openapi.json")

# Initialize model and processor
MODEL_NAME = "bytedance-research/UI-TARS-7B-DPO"
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # Load in float16 with low CPU memory usage to reduce the memory footprint
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_NAME, torch_dtype=torch.float16, low_cpu_mem_usage=True
    ).to(device)
except RuntimeError as e:
    if "CUDA out of memory" in str(e):
        print("Warning: loading the model in float16 on GPU failed due to insufficient memory. Falling back to CPU and float32.")
        device = "cpu"  # Switch to CPU
        # Release the partially allocated GPU memory before reloading
        gc.collect()
        torch.cuda.empty_cache()
        # Reload in float32 on CPU, still with low CPU memory usage
        model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True).to(device)
    else:
        raise

processor = AutoProcessor.from_pretrained(MODEL_NAME)

# Pydantic models
class ImageUrl(BaseModel):
    url: str

class Image(BaseModel):
    type: Literal["image_url"] = "image_url"
    image_url: ImageUrl

class Content(BaseModel):
    type: Literal["text", "image_url"]
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

class Message(BaseModel):
    role: Literal["user", "system", "assistant"]
    content: Union[str, List[Content]]

class ChatCompletionRequest(BaseModel):
    messages: List[Message]
    max_tokens: Optional[int] = 128

@app.post("/chat/completions")
async def chat_completion(request: ChatCompletionRequest = Body(...)):
    # Extract the text and (optional) image from the first message
    messages = request.messages
    max_tokens = request.max_tokens
    first_message = messages[0]

    image_url = None
    text_content = None
    if isinstance(first_message.content, str):
        text_content = first_message.content
    else:
        for content_item in first_message.content:
            if content_item.type == "image_url":
                image_url = content_item.image_url.url
            elif content_item.type == "text":
                text_content = content_item.text

    # Decode the image if one was provided (only base64 data URLs are supported)
    pil_image = None
    if image_url:
        try:
            if image_url.startswith("data:image"):
                header, encoded = image_url.split(",", 1)
                image_data = base64.b64decode(encoded)
                pil_image = PILImage.open(io.BytesIO(image_data)).convert("RGB")
            else:
                print("Image URL provided, but a base64 data URL was expected.")
        except Exception as e:
            print(f"Error processing image: {e}")
            raise

    # Generate response
    try:
        inputs = processor(text=text_content, images=pil_image, return_tensors="pt").to(device)
        outputs = model.generate(**inputs, max_new_tokens=max_tokens)
        response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise

    return {
        "choices": [{
            "message": {
                "role": "assistant",
                "content": response
            }
        }]
    }

@app.get("/")
def index():
    return FileResponse("static/index.html")

@app.on_event("startup")
def startup_event():
    # On Hugging Face Spaces, the application is usually accessible at
    # https://<space-name>.hf.space; here the space is assumed to be named
    # 'api-UI-TARS-7B-DPO'.
    public_url = "https://api-UI-TARS-7B-DPO.hf.space"
    print(f"Public URL: {public_url}")
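
# A minimal client sketch (illustrative, not part of the app) showing the
# payload shape the /chat/completions endpoint expects. The base URL, port
# (7860, the usual Hugging Face Spaces default), and the screenshot filename
# are assumptions; the `requests` library is an extra dependency. The
# __main__ guard keeps this from running when uvicorn imports the module.
if __name__ == "__main__":
    import requests  # assumed installed; `pip install requests`

    # Encode a local screenshot as a base64 data URL, as the endpoint requires
    with open("screenshot.png", "rb") as f:  # hypothetical example file
        b64 = base64.b64encode(f.read()).decode()

    payload = {
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": "Where is the search button?"},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
            ],
        }],
        "max_tokens": 128,
    }
    resp = requests.post("http://localhost:7860/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])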