"""FastAPI service that turns a text description into an HTML infographic.

``GET /`` serves ``static/infographic_gen.html``.  ``POST /generate`` sends the
user's description (wrapped in the ``PROMPT_TEMPLATE`` environment variable)
to a Hugging Face chat model and returns the first fenced code block of the
reply as ``{"html": ...}``.
"""

import logging
import os
import re

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Serve static files for assets
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize Hugging Face Inference Client
client = InferenceClient()

# Chat model used for every generation request.
MODEL_ID = "Qwen/Qwen2.5-Coder-32B-Instruct"

# Single-page frontend served at "/".
FRONTEND_PATH = "static/infographic_gen.html"

# Fenced Markdown code block: opening ``` plus an optional info string
# (e.g. "html") up to the end of that line, then the body captured lazily up
# to the next ```.  DOTALL lets the body span multiple lines.
CODE_BLOCK_PATTERN = re.compile(r'```.*?\n(.*?)```', re.DOTALL)


# Pydantic model for API input
class InfographicRequest(BaseModel):
    """Request body for ``POST /generate``."""

    description: str  # free-text description of the desired infographic


# Prompt template loaded from the environment.  It is filled with
# str.format(description=...), so it is expected to contain a
# "{description}" placeholder and any literal braces must be doubled.
# None when the variable is unset; checked on each request.
PROMPT_TEMPLATE = os.getenv("PROMPT_TEMPLATE")


async def extract_code_blocks(markdown_text):
    """Extract the bodies of fenced (triple-backtick) code blocks.

    Kept ``async`` so existing ``await`` call sites keep working; the work
    itself is a synchronous regex scan.

    Args:
        markdown_text (str): The Markdown content as a string.

    Returns:
        list[str]: Code-block bodies in document order, without the fences
        or the info string; empty when no block is found.
    """
    return CODE_BLOCK_PATTERN.findall(markdown_text)


# Route to serve the HTML template
@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    """Serve the single-page frontend HTML file."""
    # `with` guarantees the handle is closed (the original leaked it).
    with open(FRONTEND_PATH, encoding="utf-8") as html_file:
        return HTMLResponse(html_file.read())


def _stream_completion(prompt):
    """Run the blocking streaming chat completion and return the full reply.

    Chunks with no choices, or whose delta content is None/empty, are skipped
    instead of being concatenated (concatenating None raises TypeError).
    """
    stream = client.chat.completions.create(
        model=MODEL_ID,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.5,
        max_tokens=32000,
        top_p=0.7,
        stream=True,
    )
    parts = []
    for chunk in stream:
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content:
            parts.append(content)
    # join is linear; repeated `+=` on a growing string is not guaranteed to be.
    return "".join(parts)


# Route to handle infographic generation
@app.post("/generate")
async def generate_infographic(request: InfographicRequest):
    """Generate infographic HTML from the request's description.

    Args:
        request (InfographicRequest): Body carrying ``description``.

    Returns:
        JSONResponse: ``{"html": <first code block>}`` on success; otherwise
        ``{"error": ...}`` with status 500 when ``PROMPT_TEMPLATE`` is unset,
        formatting or the model call fails, or the reply has no code block.
    """
    if PROMPT_TEMPLATE is None:
        return JSONResponse(
            content={"error": "PROMPT_TEMPLATE environment variable is not set"},
            status_code=500,
        )

    try:
        # Inside the try so a malformed template (stray braces) yields the
        # same JSON error response instead of an unhandled exception.
        prompt = PROMPT_TEMPLATE.format(description=request.description)
        # The HF client is synchronous; run it in a worker thread so a long
        # generation does not block the event loop for all other requests.
        generated_text = await run_in_threadpool(_stream_completion, prompt)
        logger.debug("Model output: %s", generated_text)
        code_blocks = await extract_code_blocks(generated_text)
    except Exception as e:
        # Top-level request boundary: record the traceback, return JSON error.
        logger.exception("Infographic generation failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)

    if code_blocks:
        return JSONResponse(content={"html": code_blocks[0]})
    return JSONResponse(content={"error": "No generation"}, status_code=500)