import asyncio
import logging
import tempfile
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from datasets import Dataset, load_dataset

from models import chunk_config, embed_config, env_config, WebhookPayload
from chunking_utils import Chunker, chunk_generator
from embed_utils import wake_up_endpoint, embed_wrapper

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide mutable state; populated by ``lifespan`` on startup.
app_state = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create per-process state on startup and discard it on shutdown."""
    # Head SHAs that have already been scheduled, used to de-duplicate webhooks.
    # NOTE(review): this set only grows for the lifetime of the process.
    app_state["seen_Sha"] = set()
    yield
    app_state.clear()


app = FastAPI(lifespan=lifespan)
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the landing page from ``templates/index.html``."""
    return templates.TemplateResponse(request=request, name="index.html")


@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
):
    """Schedule chunking + embedding for a dataset content update.

    Work is scheduled only when the event is an ``update`` whose scope starts
    with ``repo.content`` on a repo of type ``dataset``, and whose head SHA has
    not been seen before. Otherwise nothing is scheduled.

    Returns:
        ``{"processed": True}`` if a background task was queued, else
        ``{"processed": False}``.
    """
    if not (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.type == "dataset"
        # The webhook can post the same update several times; skip head SHAs
        # that were already scheduled.
        and payload.repo.headSha not in app_state["seen_Sha"]
    ):
        logger.info("Update detected, no action taken")
        return {"processed": False}

    # There is no ``await`` between the membership check above and this add,
    # so concurrent requests on the event loop cannot interleave here.
    app_state["seen_Sha"].add(payload.repo.headSha)
    task_queue.add_task(chunk_and_embed, input_ds_name=payload.repo.name)

    return {"processed": True}


def chunk(ds_name):
    """Chunk the dataset ``ds_name`` and push the chunked dataset to the Hub.

    The splits listed in ``env_config.input_splits`` are loaded together and fed
    through ``chunk_generator`` with a ``Chunker`` built from ``chunk_config``.
    The resulting dataset is pushed to ``env_config.chunked_ds_name``.

    Returns:
        The still-open temporary ``.jsonl`` file passed to ``chunk_generator``
        as ``tmp_file``. ``embed`` reads the chunks back from it by name. The
        caller owns it and must ``close()`` it; closing also deletes it, because
        ``NamedTemporaryFile`` defaults to ``delete=True``.

    If chunking or the upload fails, the temporary file is closed here before
    the exception propagates.
    """
    logger.info("Update detected, chunking is scheduled")
    input_ds = load_dataset(ds_name, split="+".join(env_config.input_splits))
    chunker = Chunker(
        strategy=chunk_config.strategy,
        split_seq=chunk_config.split_seq,
        chunk_len=chunk_config.chunk_len,
    )

    tmp_file = tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl")
    try:
        dataset = Dataset.from_generator(
            chunk_generator,
            gen_kwargs={
                "input_dataset": input_ds,
                "chunker": chunker,
                "tmp_file": tmp_file,
            },
        )
        dataset.push_to_hub(
            env_config.chunked_ds_name,
            private=chunk_config.private,
            token=env_config.hf_token,
        )
    except BaseException:
        # Ownership never reached the caller, so clean up the temp file here.
        tmp_file.close()
        raise

    logger.info("Done chunking")
    return tmp_file


def embed(chunked_file):
    """Embed the chunks stored in ``chunked_file`` and push them to the Hub.

    Args:
        chunked_file: the open temporary file returned by ``chunk``. It is
            closed (and thus deleted) before this function returns, whether or
            not embedding succeeds.

    This is a synchronous function that drives ``embed_wrapper`` with
    ``asyncio.run``. It therefore must not be called from a thread that already
    has a running event loop (``asyncio.run`` raises ``RuntimeError`` there).
    The embedded dataset is pushed to ``env_config.embed_ds_name``.
    """
    try:
        logger.info("Update detected, embedding is scheduled")
        wake_up_endpoint()
        # The chunks are re-read by file *name*. Flush any writes still
        # buffered in this handle first; this is harmless if they were
        # already flushed.
        chunked_file.flush()
        chunked_ds = Dataset.from_json(chunked_file.name)

        with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
            asyncio.run(embed_wrapper(chunked_ds, temp_file))
            # Same reasoning: the embeddings are re-read by name just below.
            temp_file.flush()
            emb_ds = Dataset.from_json(temp_file.name)
            emb_ds.push_to_hub(
                env_config.embed_ds_name,
                private=embed_config.private,
                token=env_config.hf_token,
            )
    finally:
        # close() is idempotent, so this is safe even if something else
        # already closed the file.
        chunked_file.close()

    logger.info("Done embedding")
    return


def chunk_and_embed(input_ds_name):
    """Background task: chunk ``input_ds_name``, then embed the chunks.

    Failures are logged with the dataset name and then re-raised.
    The returned dict is not used by ``post_webhook``, which only passes
    this function to ``add_task``.
    """
    try:
        chunked_tmp_file = chunk(input_ds_name)
        embed(chunked_tmp_file)
    except Exception:
        logger.exception("Chunking/embedding failed for dataset %s", input_ds_name)
        raise
    return {"processed": True}


# For debugging
# import uvicorn
# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=7860)