from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
import os
import asyncio

app = FastAPI()

# Mount static files directory
app.mount("/static", StaticFiles(directory="static"), name="static")

# Set up Jinja2 templates
templates = Jinja2Templates(directory="templates")

# Initialize the Hugging Face Inference Client
client = InferenceClient()


async def generate_stream_response(prompt_template: str, **kwargs):
    """
    Generate a streaming response using the Hugging Face Inference Client.

    Args:
        prompt_template (str): Name of the environment variable holding the prompt template
        **kwargs: Dynamic arguments used to format the prompt

    Yields:
        str: Streamed content chunks
    """
    # Each prompt template is read from an environment variable, e.g. (hypothetical):
    #   PROMPT_SOLVE="Solve the following problem step by step: {problem}"
    template = os.getenv(prompt_template)
    if template is None:
        yield f"Error occurred: prompt template '{prompt_template}' is not configured"
        return
    prompt = template.format(**kwargs)

    # Prepare messages for the model
    messages = [
        {"role": "user", "content": prompt}
    ]

    try:
        # Create a stream for the chat completion
        stream = client.chat.completions.create(
            model="Qwen/Qwen2.5-Math-1.5B-Instruct",
            messages=messages,
            temperature=0.7,
            max_tokens=1024,
            top_p=0.8,
            stream=True
        )

        # Stream the generated content. InferenceClient is synchronous, so
        # explicitly yield control back to the event loop between chunks.
        for chunk in stream:
            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
            await asyncio.sleep(0)

    except Exception as e:
        yield f"Error occurred: {str(e)}"


@app.websocket("/ws/{endpoint}")
async def websocket_endpoint(websocket: WebSocket, endpoint: str):
    """
    WebSocket endpoint for streaming responses.

    Args:
        websocket (WebSocket): The WebSocket connection
        endpoint (str): The specific endpoint/task to process
    """
    await websocket.accept()
    try:
        # Receive the initial message with parameters
        data = await websocket.receive_json()

        # Map the endpoint to the appropriate prompt template
        endpoint_prompt_map = {
            "solve": "PROMPT_SOLVE",
            "hint": "PROMPT_HINT",
            "verify": "PROMPT_VERIFY",
            "generate": "PROMPT_GENERATE",
            "explain": "PROMPT_EXPLAIN"
        }

        # Get the appropriate prompt template
        prompt_template = endpoint_prompt_map.get(endpoint)
        if not prompt_template:
            await websocket.send_json({"error": "Invalid endpoint"})
            return

        # Stream the response
        full_response = ""
        async for chunk in generate_stream_response(prompt_template, **data):
            full_response += chunk
            await websocket.send_json({"chunk": chunk})

        # Send a final message to indicate streaming is complete
        await websocket.send_json({"complete": True, "full_response": full_response})

    except WebSocketDisconnect:
        # Client disconnected mid-stream; nothing left to send
        pass
    except Exception as e:
        await websocket.send_json({"error": str(e)})
    finally:
        try:
            await websocket.close()
        except RuntimeError:
            # Connection already closed
            pass


# Existing routes remain the same as in the previous implementation
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    with open("static/index.html") as f:
        return HTMLResponse(f.read())
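

# Example client usage (a minimal sketch, kept as a comment so it does not run
# on import). It assumes the server is listening on localhost:8000, that the
# third-party `websockets` package is installed, and that the hypothetical
# PROMPT_SOLVE template takes a single `problem` field. Save it as a separate
# script (e.g. client.py) and run it while the server is up:
#
#     import asyncio
#     import json
#     import websockets
#
#     async def main():
#         async with websockets.connect("ws://localhost:8000/ws/solve") as ws:
#             # Send the parameters used to format the prompt template
#             await ws.send(json.dumps({"problem": "What is 17 * 24?"}))
#             while True:
#                 msg = json.loads(await ws.recv())
#                 if "chunk" in msg:
#                     print(msg["chunk"], end="", flush=True)
#                 elif msg.get("complete") or "error" in msg:
#                     break
#
#     asyncio.run(main())


# A minimal way to launch the server directly (a sketch, assuming this file is
# saved as app.py and that uvicorn, FastAPI's usual ASGI server, is installed):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)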