# webhook-space / main.py
# Author: plaggy
# Commit: 8b7a023 ("refactor")
# Size: 3.32 kB
import asyncio
import logging
import tempfile
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from contextlib import asynccontextmanager
from datasets import Dataset, load_dataset
from models import chunk_config, embed_config, env_config, WebhookPayload
from chunking_utils import Chunker, chunk_generator
from embed_utils import wake_up_endpoint, embed_wrapper
# Root logging at INFO so the pipeline's logger.info progress messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Process-wide mutable state. The lifespan hook stores a set of already-seen
# head commit SHAs under "seen_Sha"; post_webhook checks and extends it.
app_state: dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up and tear down process-wide state around the app's lifetime.

    On startup, register an empty set under ``app_state["seen_Sha"]``. The
    webhook handler uses it to skip head SHAs it has already scheduled.
    On shutdown, empty ``app_state`` entirely.
    """
    app_state.update(seen_Sha=set())
    yield
    app_state.clear()
# Application instance; the lifespan hook initialises and clears app_state.
app = FastAPI(lifespan=lifespan)
# Serve files from the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Jinja2 templates (e.g. index.html rendered by root) are loaded from "templates".
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Serve the landing page rendered from the ``index.html`` template."""
    template_name = "index.html"
    return templates.TemplateResponse(name=template_name, request=request)
@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks
):
    """Schedule chunking + embedding for new content updates to a dataset repo.

    An event is acted on only when all of the following hold:
    - it is an ``update`` whose scope starts with ``repo.content``;
    - it targets a ``dataset`` repo;
    - its head SHA has not been seen before.

    For such events, the SHA is recorded in ``app_state["seen_Sha"]`` and
    ``chunk_and_embed`` is queued as a background task. Entries are never
    removed from that set here.

    Returns:
        ``{"processed": True}`` if work was scheduled, otherwise
        ``{"processed": False}``.
    """
    is_dataset_content_update = (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.type == "dataset"
    )
    # The webhook posts several requests for the same update; only the first
    # request carrying a given head SHA is acted on.
    is_new_commit = (
        is_dataset_content_update
        and payload.repo.headSha not in app_state["seen_Sha"]
    )

    if is_new_commit:
        app_state["seen_Sha"].add(payload.repo.headSha)
        task_queue.add_task(chunk_and_embed, input_ds_name=payload.repo.name)
        return {"processed": True}

    logger.info("Update detected, no action taken")
    return {"processed": False}
def chunk(ds_name):
    """Chunk a Hub dataset, push the chunked dataset, and return the chunk file.

    Steps:
    1. Load the ``env_config.input_splits`` splits of ``ds_name``.
    2. Split them with a ``Chunker`` configured from ``chunk_config``.
    3. Push the resulting dataset to ``env_config.chunked_ds_name``.

    ``chunk_generator`` is also handed an open temporary ``.jsonl`` file. That
    file is returned still open so the caller can re-read it by name.

    Args:
        ds_name: Hub id of the dataset to chunk.

    Returns:
        The open ``tempfile.NamedTemporaryFile``. The caller owns it and must
        close it; with the default ``delete=True``, closing also removes it
        from disk.

    Raises:
        Whatever ``load_dataset``, ``Dataset.from_generator`` or
        ``push_to_hub`` raise. The temp file is closed first in that case.
    """
    logger.info("Update detected, chunking is scheduled")
    input_ds = load_dataset(ds_name, split="+".join(env_config.input_splits))
    chunker = Chunker(
        strategy=chunk_config.strategy,
        split_seq=chunk_config.split_seq,
        chunk_len=chunk_config.chunk_len,
    )
    tmp_file = tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl")
    try:
        dataset = Dataset.from_generator(
            chunk_generator,
            gen_kwargs={
                "input_dataset": input_ds,
                "chunker": chunker,
                "tmp_file": tmp_file,
            },
        )
        # The handle is buffered text, and the file is later re-read by *name*
        # rather than through this handle. Flush so anything chunk_generator
        # wrote is actually on disk (a no-op if it already flushed).
        tmp_file.flush()
        dataset.push_to_hub(
            env_config.chunked_ds_name,
            private=chunk_config.private,
            token=env_config.hf_token,
        )
    except BaseException:
        # The caller never receives the handle on failure, so close (and
        # thereby delete) the temp file here instead of leaking it.
        tmp_file.close()
        raise
    logger.info("Done chunking")
    return tmp_file
def embed(chunked_file):
    """Embed the chunks stored in ``chunked_file`` and push them to the Hub.

    Steps:
    1. Call ``wake_up_endpoint()`` (presumably to make sure the embedding
       endpoint is up; not verified here).
    2. Load the chunks from ``chunked_file.name``.
    3. Have ``embed_wrapper`` write embeddings to a temporary ``.jsonl`` file.
    4. Push that file as a dataset to ``env_config.embed_ds_name``.

    Args:
        chunked_file: Open temporary file returned by ``chunk``. It is read by
            name and is always closed before this function returns, even when
            a step fails.

    Raises:
        Whatever the endpoint call, dataset loading, embedding or
        ``push_to_hub`` raise, after ``chunked_file`` has been closed.
    """
    logger.info("Update detected, embedding is scheduled")
    try:
        wake_up_endpoint()
        chunked_ds = Dataset.from_json(chunked_file.name)
        with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
            # asyncio.run needs a thread with no running event loop. That is
            # assumed here: chunk_and_embed is a plain def, which FastAPI runs
            # in a worker thread.
            asyncio.run(embed_wrapper(chunked_ds, temp_file))
            # The file is re-read by name below; push any buffered writes to disk.
            temp_file.flush()
            emb_ds = Dataset.from_json(temp_file.name)
            emb_ds.push_to_hub(
                env_config.embed_ds_name,
                private=embed_config.private,
                token=env_config.hf_token,
            )
    finally:
        # Closing the NamedTemporaryFile also deletes it; do it on every path.
        chunked_file.close()
    logger.info("Done embedding")
def chunk_and_embed(input_ds_name):
    """Run the full pipeline for one dataset: chunk it, then embed the chunks.

    Queued as a background task by ``post_webhook``. ``chunk`` returns an open
    temporary file that ``embed`` consumes. This function owns that file's
    lifetime and guarantees it is closed even if ``embed`` raises.

    Args:
        input_ds_name: Hub id of the dataset whose update triggered the run.

    Returns:
        ``{"processed": True}``. Kept for compatibility; the background-task
        runner does not use the return value.
    """
    # ``with`` closes (and thereby deletes) the temp file on every path.
    # Closing it again after embed() already did is a no-op.
    with chunk(input_ds_name) as chunked_tmp_file:
        embed(chunked_tmp_file)
    return {"processed": True}
# For debugging
# import uvicorn
# if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=7860)