# app.py — from yasserrmd's repository (last commit "Update app.py", 11511b6, 2.39 kB).
import os
import re

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
from pydantic import BaseModel
# FastAPI application instance; every route in this module is registered on it.
app = FastAPI()

# Expose the local "static" directory (relative to the working directory)
# under the /static URL prefix, registered under the route name "static".
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# Hugging Face inference client built with no explicit arguments, so the
# token/endpoint come from huggingface_hub's own defaults (presumably the
# HF_TOKEN environment variable or cached login — TODO confirm for deployment).
client = InferenceClient()
# Pydantic model for API input
class InfographicRequest(BaseModel):
    """JSON request body accepted by the ``POST /generate`` route."""

    # Free-text description; substituted into PROMPT_TEMPLATE via its
    # ``{description}`` placeholder by the /generate handler.
    description: str
# Prompt template taken from the PROMPT_TEMPLATE environment variable. The
# /generate handler fills it with ``str.format(description=...)``, so it is
# expected to contain a ``{description}`` placeholder. None when unset.
PROMPT_TEMPLATE = os.environ.get("PROMPT_TEMPLATE")
async def extract_code_blocks(markdown_text):
    """Collect the contents of every triple-backtick fenced block in Markdown.

    Declared ``async`` although it performs no awaiting, so callers must
    ``await`` it to obtain the list.

    Args:
        markdown_text (str): Markdown source to scan.

    Returns:
        list: The body of each fenced block, in order of appearance. The
        opening fence line (including any language tag such as ``html``) is
        dropped, and the body is returned verbatim up to the closing fence,
        so it usually ends with a newline. A fence that is never closed
        yields no entry.
    """
    # Opening ``` plus the rest of its line, then a lazily matched body up to
    # the next ```; DOTALL lets the body span multiple lines.
    fence = re.compile(r'```.*?\n(.*?)```', re.DOTALL)
    return [match.group(1) for match in fence.finditer(markdown_text)]
# Route to serve the HTML template
@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    """Serve the single-page frontend stored in ``infographic_gen.html``.

    The file path is relative to the process working directory and the file
    is re-read on every request. It is decoded explicitly as UTF-8: relying on
    the platform's default locale encoding can fail to decode (or mis-decode)
    a UTF-8 HTML file on hosts whose locale is not UTF-8.

    Returns:
        HTMLResponse: the file contents.
    """
    with open("infographic_gen.html", "r", encoding="utf-8") as file:
        return HTMLResponse(content=file.read())
# Route to handle infographic generation
def _collect_completion_text(prompt):
    """Run a streaming chat completion for *prompt* and return the joined text.

    Synchronous: iterating the ``InferenceClient`` stream blocks, so the async
    route below calls this through ``run_in_threadpool`` instead of running it
    on the event loop.

    Args:
        prompt (str): Fully formatted user prompt.

    Returns:
        str: Concatenation of every non-empty ``delta.content`` in the stream.
    """
    stream = client.chat.completions.create(
        model="Qwen/Qwen2.5-Coder-32B-Instruct",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.5,
        max_tokens=1024,
        top_p=0.7,
        stream=True,
    )
    parts = []
    for chunk in stream:
        # Some stream chunks may carry no choices or a None delta content
        # (e.g. a role-only or finish chunk); skip them rather than failing
        # on ``str + None``.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content:
            parts.append(content)
    return "".join(parts)


@app.post("/generate")
async def generate_infographic(request: InfographicRequest):
    """Generate infographic HTML from the request's description.

    Formats PROMPT_TEMPLATE with the description, asks the model for a
    completion, and returns the first fenced code block from its answer.

    Args:
        request (InfographicRequest): Body containing ``description``.

    Returns:
        JSONResponse: ``{"html": ...}`` on success, otherwise
        ``{"error": ...}`` with status 500 (template missing, model/API
        failure, or no fenced code block in the output).
    """
    if PROMPT_TEMPLATE is None:
        return JSONResponse(
            content={"error": "PROMPT_TEMPLATE environment variable is not set"},
            status_code=500,
        )
    try:
        # Formatting is inside the try so a malformed template (e.g. stray
        # braces) is reported as a JSON error like any other failure.
        prompt = PROMPT_TEMPLATE.format(description=request.description)
        generated_text = await run_in_threadpool(_collect_completion_text, prompt)
        # extract_code_blocks is declared async: without ``await`` the result
        # is a coroutine object, which is always truthy and cannot be indexed.
        code_blocks = await extract_code_blocks(generated_text)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
    if code_blocks:
        return JSONResponse(content={"html": code_blocks[0]})
    # Also reached when output is cut off at max_tokens before the closing
    # fence, since an unclosed block is not extracted.
    return JSONResponse(content={"error": "No generation"}, status_code=500)