File size: 3,207 Bytes
97fdba5
 
 
 
187981b
 
 
 
8b7a023
97fdba5
 
8b7a023
 
 
97fdba5
 
 
 
8b7a023
 
b70008c
8b7a023
 
 
 
 
97fdba5
f05487c
8b7a023
97fdba5
187981b
 
97fdba5
60d1cfd
187981b
 
 
97fdba5
 
 
 
 
 
 
 
b70008c
 
 
 
8b7a023
97fdba5
 
 
 
8b7a023
 
97fdba5
 
 
 
8b7a023
97fdba5
8b7a023
97fdba5
 
 
 
 
8b7a023
97fdba5
 
 
 
8b7a023
 
97fdba5
 
 
 
8b7a023
97fdba5
8b7a023
97fdba5
 
 
8b7a023
97fdba5
 
8b7a023
97fdba5
8b7a023
 
97fdba5
8b7a023
97fdba5
8b7a023
 
 
97fdba5
8b7a023
97fdba5
 
8b7a023
97fdba5
8b7a023
 
 
 
 
 
 
3dca465
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import asyncio
import logging
import tempfile

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from contextlib import asynccontextmanager
from datasets import Dataset, load_dataset

from models import chunk_config, embed_config, env_config, WebhookPayload
from chunking_utils import Chunker, chunk_generator
from embed_utils import wake_up_endpoint, embed_wrapper

# Configure the root logger at INFO so the pipeline's progress messages
# are emitted, then grab this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide mutable state shared between the lifespan hook and the
# request handlers (e.g. the set of already-handled commit SHAs).
app_state: dict = dict()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up and tear down ``app_state`` around the application's lifetime.

    Before the app starts serving, registers an empty set under
    ``"seen_Sha"`` (used to de-duplicate webhook deliveries). After the app
    shuts down, empties ``app_state`` entirely.
    """
    app_state.update({"seen_Sha": set()})
    yield
    app_state.clear()


app = FastAPI(lifespan=lifespan)

# Jinja2 environment for HTML pages, loaded from ./templates.
templates = Jinja2Templates(directory="templates")
# Serve files under ./static at the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")


@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render and return the landing page (``index.html`` from ./templates)."""
    page = "index.html"
    return templates.TemplateResponse(name=page, request=request)


@app.post("/webhook")
async def post_webhook(
        payload: WebhookPayload,
        task_queue: BackgroundTasks
):
    """Handle a Hub webhook delivery and schedule the pipeline when relevant.

    Only a content update (``action == "update"``, scope starting with
    ``"repo.content"``) of a dataset repo whose head SHA has not been seen
    before triggers work. In that case the SHA is recorded and
    ``chunk_and_embed`` is queued as a background task.

    Returns:
        ``{"processed": True}`` when a task was queued, otherwise
        ``{"processed": False}``.
    """
    event = payload.event
    # Checks are chained with ``and`` so later (repo / state) lookups only
    # happen once the earlier conditions hold.
    is_new_dataset_update = (
        event.action == "update"
        and event.scope.startswith("repo.content")
        and payload.repo.type == "dataset"
        # The webhook may deliver the same update several times; a head
        # SHA that was already queued is treated as a duplicate.
        and payload.repo.headSha not in app_state["seen_Sha"]
    )

    if is_new_dataset_update:
        app_state["seen_Sha"].add(payload.repo.headSha)
        task_queue.add_task(chunk_and_embed, input_ds_name=payload.repo.name)
        return {"processed": True}

    logger.info("Update detected, no action taken")
    return {"processed": False}


def chunk(ds_name):
    """Chunk the source dataset and publish the chunks to the Hub.

    Loads the splits listed in ``env_config.input_splits`` from ``ds_name``
    (joined with ``"+"`` into a single split). It then builds a ``Chunker``
    from ``chunk_config`` and materialises a ``Dataset`` via
    ``chunk_generator``. Finally it pushes that dataset to
    ``env_config.chunked_ds_name``.

    Args:
        ds_name: Hub id of the dataset whose update triggered the pipeline.

    Returns:
        The still-open ``NamedTemporaryFile`` that was handed to
        ``chunk_generator`` (presumably it receives the chunk records as
        JSON lines -- confirm in ``chunking_utils``). The caller owns it and
        must close it. With the default ``delete=True``, closing also
        removes it from disk.

    Raises:
        Whatever ``load_dataset``, ``Dataset.from_generator`` or
        ``push_to_hub`` raise. If the temporary file had already been
        created, it is closed before the exception propagates.
    """
    logger.info("Update detected, chunking is scheduled")
    input_ds = load_dataset(ds_name, split="+".join(env_config.input_splits))
    chunker = Chunker(
        strategy=chunk_config.strategy,
        split_seq=chunk_config.split_seq,
        chunk_len=chunk_config.chunk_len
    )
    tmp_file = tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl")
    try:
        dataset = Dataset.from_generator(
            chunk_generator,
            gen_kwargs={
                "input_dataset": input_ds,
                "chunker": chunker,
                "tmp_file": tmp_file
            }
        )
        # mode="a" gives a buffered text handle. Flush so that a reader that
        # reopens the file by its name sees everything written to it.
        tmp_file.flush()

        dataset.push_to_hub(
            env_config.chunked_ds_name,
            private=chunk_config.private,
            token=env_config.hf_token
        )
    except BaseException:
        # Nobody else holds a reference yet; close it (which also deletes
        # it) rather than leaking the descriptor and the on-disk file.
        tmp_file.close()
        raise

    logger.info("Done chunking")
    return tmp_file


def embed(chunked_file):
    """Embed previously chunked records and publish them to the Hub.

    Steps:
        1. Wake the embedding endpoint.
        2. Load the chunk records from ``chunked_file`` (by path).
        3. Run ``embed_wrapper``, giving it a fresh temporary ``.jsonl``
           file to write into.
        4. Read that file back and push the result to
           ``env_config.embed_ds_name``.

    Args:
        chunked_file: Open temporary file as returned by ``chunk``. This
            function takes ownership of it: it is closed on every exit
            path, including when a step raises.

    Returns:
        None.
    """
    logger.info("Update detected, embedding is scheduled")
    try:
        wake_up_endpoint()
        # The handle may still hold buffered writes; flush before the file
        # is reopened by name.
        chunked_file.flush()
        chunked_ds = Dataset.from_json(chunked_file.name)
        with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
            # asyncio.run needs a thread with no event loop already running.
            asyncio.run(embed_wrapper(chunked_ds, temp_file))
            # embed_wrapper writes through this buffered text handle.
            # Flush before reading the file back by name, or trailing
            # records may be missing.
            temp_file.flush()

            emb_ds = Dataset.from_json(temp_file.name)
            emb_ds.push_to_hub(
                env_config.embed_ds_name,
                private=embed_config.private,
                token=env_config.hf_token
            )
    finally:
        # Previously skipped whenever a step above raised.
        chunked_file.close()

    logger.info("Done embedding")


def chunk_and_embed(input_ds_name):
    """Run the full pipeline for one source-dataset update.

    The temporary file produced by ``chunk`` is passed straight to
    ``embed``. This is queued via ``BackgroundTasks.add_task`` in
    ``post_webhook``, which does not use the returned dict.
    """
    embed(chunk(input_ds_name))
    return {"processed": True}