Spaces:
Runtime error
Runtime error
File size: 3,207 Bytes
97fdba5 187981b 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 b70008c 8b7a023 97fdba5 f05487c 8b7a023 97fdba5 187981b 97fdba5 60d1cfd 187981b 97fdba5 b70008c 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 97fdba5 8b7a023 3dca465 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import asyncio
import logging
import tempfile
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from contextlib import asynccontextmanager
from datasets import Dataset, load_dataset
from models import chunk_config, embed_config, env_config, WebhookPayload
from chunking_utils import Chunker, chunk_generator
from embed_utils import wake_up_endpoint, embed_wrapper
# Root logging at INFO so the pipeline's progress messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Process-wide mutable state: populated by the lifespan hook, read and
# updated by the webhook handler.
app_state = dict()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up and tear down shared application state.

    On startup, registers an empty set of already-processed head SHAs under
    ``app_state["seen_Sha"]``. On shutdown, wipes ``app_state`` entirely.
    """
    app_state.update(seen_Sha=set())
    yield  # the application serves requests while suspended here
    app_state.clear()
# The ASGI application; ``lifespan`` manages ``app_state`` across startup/shutdown.
app = FastAPI(lifespan=lifespan)
# Serve files from ./static under the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
# Jinja2 templates are loaded from ./templates.
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the ``index.html`` template as the landing page."""
    page_template = "index.html"
    return templates.TemplateResponse(name=page_template, request=request)
@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks
):
    """Handle a Hub webhook call.

    A background ``chunk_and_embed`` run is scheduled only for an ``update``
    event whose scope starts with ``repo.content``, on a repo of type
    ``dataset``, whose ``headSha`` has not been handled before.

    Returns:
        ``{"processed": True}`` when work was scheduled, otherwise
        ``{"processed": False}``.
    """
    event = payload.event
    repo = payload.repo
    seen = app_state["seen_Sha"]

    is_fresh_dataset_update = (
        event.action == "update"
        and event.scope.startswith("repo.content")
        and repo.type == "dataset"
        # The webhook delivers several requests for one update; dedupe on SHA.
        # NOTE(review): this set is never pruned, so it grows for the process lifetime.
        and repo.headSha not in seen
    )

    if is_fresh_dataset_update:
        seen.add(repo.headSha)
        task_queue.add_task(chunk_and_embed, input_ds_name=repo.name)
        return {"processed": True}

    logger.info("Update detected, no action taken")
    return {"processed": False}
def chunk(ds_name):
    """Chunk a Hub dataset, push the chunks to the Hub, and return the temp file.

    Loads ``ds_name`` using the splits in ``env_config.input_splits`` (joined
    with ``+``). It splits each row with a ``Chunker`` configured from
    ``chunk_config`` and pushes the resulting dataset to
    ``env_config.chunked_ds_name``.

    Args:
        ds_name: Hub identifier of the input dataset.

    Returns:
        The still-open ``NamedTemporaryFile`` handed to ``chunk_generator``.
        The caller reads it by name and must close it. With the default
        ``delete=True``, closing also deletes it.

    Raises:
        Whatever ``load_dataset``, ``Dataset.from_generator`` or
        ``push_to_hub`` raise. In that case the temp file is closed (and
        deleted) before the exception propagates.
    """
    logger.info("Update detected, chunking is scheduled")
    input_ds = load_dataset(ds_name, split="+".join(env_config.input_splits))
    chunker = Chunker(
        strategy=chunk_config.strategy,
        split_seq=chunk_config.split_seq,
        chunk_len=chunk_config.chunk_len,
    )
    # Deliberately not a ``with`` block: on success the open file is returned
    # to the caller, who owns (and eventually closes) it.
    tmp_file = tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl")
    try:
        dataset = Dataset.from_generator(
            chunk_generator,
            gen_kwargs={
                "input_dataset": input_ds,
                "chunker": chunker,
                "tmp_file": tmp_file,
            },
        )
        # Presumably chunk_generator writes chunks through this buffered handle.
        # Flush so a reader that re-opens the file by name sees all of it.
        tmp_file.flush()
        dataset.push_to_hub(
            env_config.chunked_ds_name,
            private=chunk_config.private,
            token=env_config.hf_token,
        )
    except BaseException:
        # Don't leak the temp file when chunking or uploading fails.
        tmp_file.close()
        raise
    logger.info("Done chunking")
    return tmp_file
def embed(chunked_file):
    """Embed the chunked dataset and push the embeddings to the Hub.

    Wakes the embedding endpoint and loads the chunks from the JSON-lines
    file at ``chunked_file.name``. It runs ``embed_wrapper`` to write
    embeddings into a fresh temp file, then loads that file and pushes it to
    ``env_config.embed_ds_name``.

    Args:
        chunked_file: Open temp-file object produced by ``chunk``. This
            function takes ownership of it and always closes it, including
            when any step fails.

    Returns:
        None.
    """
    logger.info("Update detected, embedding is scheduled")
    try:
        wake_up_endpoint()
        # Ensure any buffered writes are on disk before re-opening by name.
        chunked_file.flush()
        chunked_ds = Dataset.from_json(chunked_file.name)
        with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
            asyncio.run(embed_wrapper(chunked_ds, temp_file))
            # embed_wrapper presumably writes through this buffered handle.
            # Flush before reading the file back by name.
            temp_file.flush()
            emb_ds = Dataset.from_json(temp_file.name)
        emb_ds.push_to_hub(
            env_config.embed_ds_name,
            private=embed_config.private,
            token=env_config.hf_token,
        )
    finally:
        # Close (and thereby delete) the chunk temp file even on failure.
        chunked_file.close()
    logger.info("Done embedding")
def chunk_and_embed(input_ds_name):
    """Run the full pipeline for one dataset: chunk it, then embed the chunks.

    Args:
        input_ds_name: Hub identifier of the dataset that triggered the webhook.

    Returns:
        ``{"processed": True}`` once both stages have finished.
    """
    embed(chunk(input_ds_name))
    return {"processed": True}