Spaces:
Runtime error
Runtime error
import asyncio | |
import logging | |
import tempfile | |
from fastapi import FastAPI, Request, BackgroundTasks | |
from fastapi.responses import HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from contextlib import asynccontextmanager | |
from datasets import Dataset, load_dataset | |
from models import chunk_config, embed_config, env_config, WebhookPayload | |
from chunking_utils import Chunker, chunk_generator | |
from embed_utils import wake_up_endpoint, embed_wrapper | |
# Process-wide scratch state shared between the lifespan hook and the request
# handlers in this module.
app_state = {}

# Emit INFO-level records and above through the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up and tear down the shared application state.

    On startup an empty ``seen_Sha`` set is stored in ``app_state``; on
    shutdown ``app_state`` is emptied again.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        Nothing; control returns to the application while it is serving.
    """
    app_state["seen_Sha"] = set()
    try:
        yield
    finally:
        # Run the cleanup even if an exception is thrown in at the yield.
        app_state.clear()
# Jinja2 templates are looked up in the local "templates" directory.
templates = Jinja2Templates(directory="templates")

# Application instance; startup/shutdown is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)

# Serve the local "static" directory under the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the ``index.html`` template for the incoming request.

    Args:
        request: The incoming request, forwarded to the template response.

    Returns:
        The response produced by ``templates.TemplateResponse`` for
        ``index.html``.
    """
    # NOTE(review): the "/" path is inferred from the handler name; confirm it
    # matches the URL the landing page is expected to be served from.
    return templates.TemplateResponse(request=request, name="index.html")
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app` (e.g. via `@app.post(...)`).
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
):
    """Schedule chunking and embedding for a dataset content update.

    Only payloads whose event action is ``"update"``, whose scope starts with
    ``"repo.content"``, whose repo type is ``"dataset"`` and whose head SHA has
    not been recorded in ``app_state["seen_Sha"]`` are processed.

    Args:
        payload: The parsed webhook payload.
        task_queue: Background task queue that receives ``chunk_and_embed``.

    Returns:
        ``{"processed": True}`` when a background task was queued, otherwise
        ``{"processed": False}``.
    """
    event = payload.event
    should_ignore = (
        event.action != "update"
        or not event.scope.startswith("repo.content")
        or payload.repo.type != "dataset"
        # The webhook posts several requests for the same update; a head SHA
        # that was already recorded marks such a repeat.
        or payload.repo.headSha in app_state["seen_Sha"]
    )
    if should_ignore:
        logger.info("Update detected, no action taken")
        return {"processed": False}

    # Record the SHA before queueing so repeated deliveries are ignored.
    head_sha = payload.repo.headSha
    app_state["seen_Sha"].add(head_sha)
    task_queue.add_task(chunk_and_embed, input_ds_name=payload.repo.name)
    return {"processed": True}
def chunk(ds_name):
    """Chunk the input dataset and push the chunked dataset to the Hub.

    Loads the splits listed in ``env_config.input_splits`` from ``ds_name``,
    builds a ``Chunker`` from ``chunk_config``, generates the chunked dataset
    with ``chunk_generator`` and pushes it to ``env_config.chunked_ds_name``.

    Args:
        ds_name: Dataset name passed to ``load_dataset``.

    Returns:
        The still-open ``tempfile.NamedTemporaryFile`` that was handed to
        ``chunk_generator``. The caller owns it and must close it; closing
        deletes the file.

    Raises:
        Any exception raised while generating or pushing the dataset is
        re-raised after the temporary file has been closed.
    """
    logger.info("Update detected, chunking is scheduled")
    input_ds = load_dataset(ds_name, split="+".join(env_config.input_splits))
    chunker = Chunker(
        strategy=chunk_config.strategy,
        split_seq=chunk_config.split_seq,
        chunk_len=chunk_config.chunk_len,
    )

    # Deliberately not a `with` block: the file is returned to the caller and
    # must outlive this function on the success path.
    tmp_file = tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl")
    try:
        # presumably chunk_generator writes the chunks to tmp_file -- confirm
        # against chunking_utils.
        dataset = Dataset.from_generator(
            chunk_generator,
            gen_kwargs={
                "input_dataset": input_ds,
                "chunker": chunker,
                "tmp_file": tmp_file,
            },
        )
        dataset.push_to_hub(
            env_config.chunked_ds_name,
            private=chunk_config.private,
            token=env_config.hf_token,
        )
        # The file is re-opened by name downstream, so make sure buffered
        # writes have reached the disk before handing it over.
        tmp_file.flush()
    except BaseException:
        # Nobody else holds the handle yet; close (and thereby delete) the
        # temporary file instead of leaking it.
        tmp_file.close()
        raise

    logger.info("Done chunking")
    return tmp_file
def embed(chunked_file):
    """Embed the chunked dataset and push the embeddings to the Hub.

    Reads the chunked records from ``chunked_file`` by name, runs
    ``embed_wrapper`` over them with a fresh temporary JSONL file, loads that
    file as a dataset and pushes it to ``env_config.embed_ds_name``.

    Args:
        chunked_file: Open temporary file holding the chunked dataset as
            JSON lines. It is always closed before this function returns or
            raises.

    Returns:
        None.
    """
    logger.info("Update detected, embedding is scheduled")
    try:
        wake_up_endpoint()
        chunked_ds = Dataset.from_json(chunked_file.name)
        with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
            # The caller is synchronous, so drive the coroutine to completion
            # on a fresh event loop.
            asyncio.run(embed_wrapper(chunked_ds, temp_file))
            # The file is read back by name below; flush buffered writes so
            # that everything written so far is visible on disk.
            temp_file.flush()
            emb_ds = Dataset.from_json(temp_file.name)
            emb_ds.push_to_hub(
                env_config.embed_ds_name,
                private=embed_config.private,
                token=env_config.hf_token,
            )
    finally:
        # Release the chunk file even when embedding or uploading fails.
        chunked_file.close()
    logger.info("Done embedding")
def chunk_and_embed(input_ds_name):
    """Run the chunking step followed by the embedding step.

    Args:
        input_ds_name: Name of the dataset to chunk and embed.

    Returns:
        ``{"processed": True}`` once both steps have completed.

    Raises:
        Any exception raised by ``chunk`` or ``embed`` propagates after the
        temporary chunk file has been closed.
    """
    chunked_tmp_file = chunk(input_ds_name)
    try:
        embed(chunked_tmp_file)
    finally:
        # embed() closes the file when it finishes normally; closing it here
        # as well covers failures, and a second close() on a temporary file
        # is a no-op.
        chunked_tmp_file.close()
    return {"processed": True}
# For debugging | |
# import uvicorn | |
# if __name__ == "__main__": | |
# uvicorn.run(app, host="0.0.0.0", port=7860) | |