plaggy committed
Commit b70008c
1 Parent(s): cfa6dc8

minor refactor

Files changed (6):
  1. Dockerfile +0 -1
  2. chunk_config.json +1 -3
  3. embed_config.json +0 -2
  4. home.html +1 -1
  5. src/main.py +29 -20
  6. src/models.py +2 -10
Dockerfile CHANGED
@@ -10,7 +10,6 @@ WORKDIR $HOME/app
 
 COPY --chown=user requirements.txt requirements.txt
 RUN pip install --no-cache-dir --upgrade -r requirements.txt
-RUN python -m spacy download en_core_web_sm
 
 COPY --chown=user . .
 
chunk_config.json CHANGED
@@ -1,9 +1,7 @@
 {
-  "input_dataset": "sergeipetrov/transformers-diffusers-docs-raw",
   "input_splits": ["train"],
   "input_text_col": "text",
-  "output_dataset": "sergeipetrov/transformers-diffusers-docs-chunked",
-  "strategy": "spacy",
+  "strategy": "recursive",
   "split_seq": "\n\n",
   "chunk_len": 512,
   "private": "false"
embed_config.json CHANGED
@@ -1,8 +1,6 @@
 {
-  "input_dataset": "sergeipetrov/transformers-diffusers-docs-chunked",
   "input_splits": ["train"],
   "input_text_col": "text",
-  "output_dataset": "sergeipetrov/transformers-diffusers-docs-embed",
   "private": "false",
   "semaphore_bound": 5
 }
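
Both config files lose their hard-coded dataset names; as the src/main.py changes below show, the app now reads CHUNKED_DS_NAME, EMBED_DS_NAME, INPUT_SPLITS and INPUT_TEXT_COL from the environment, and the webhook payload supplies the raw dataset name. A minimal sketch of the assumed environment, with placeholder values (only the variable names come from this commit; in a Space these would normally be set as secrets/variables rather than in code):

import os

# Placeholder values; only the variable names appear in src/main.py.
os.environ["CHUNKED_DS_NAME"] = "your-user/docs-chunked"
os.environ["EMBED_DS_NAME"] = "your-user/docs-embed"
os.environ["INPUT_SPLITS"] = "train"
os.environ["INPUT_TEXT_COL"] = "text"
os.environ["TEI_URL"] = "https://your-endpoint.endpoints.huggingface.cloud"  # placeholder
os.environ["HF_TOKEN"] = "hf_..."  # placeholder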
home.html CHANGED
@@ -4,7 +4,7 @@
     <meta charset="utf-8" />
     <meta name="viewport" content="width=device-width" />
     <title>Auto chunking and embedding</title>
-    <link rel="stylesheet" href="./style.css" />
+    <link rel="stylesheet" href="style.css" />
   </head>
   <body>
     <div class="card">
src/main.py CHANGED
@@ -11,7 +11,7 @@ from fastapi import FastAPI, BackgroundTasks
 from fastapi.responses import FileResponse
 
 from aiohttp import ClientSession
-from langchain.text_splitter import SpacyTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from datasets import Dataset, load_dataset
 from tqdm import tqdm
 from tqdm.asyncio import tqdm_asyncio
@@ -22,7 +22,12 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 HF_TOKEN = os.getenv("HF_TOKEN")
+
 TEI_URL = os.getenv("TEI_URL")
+CHUNKED_DS_NAME = os.getenv("CHUNKED_DS_NAME")
+EMBED_DS_NAME = os.getenv("EMBED_DS_NAME")
+INPUT_SPLITS = os.getenv("INPUT_SPLITS")
+INPUT_TEXT_COL = os.getenv("INPUT_TEXT_COL")
 
 app = FastAPI()
 app.state.last_Sha = None
@@ -39,19 +44,19 @@ async def post_webhook(
     task_queue: BackgroundTasks
 ):
     if not (
-        payload.event.action == "update"
-        and payload.event.scope.startswith("repo.content")
-        and payload.repo.name == chunk_config.input_dataset
-        and payload.repo.type == "dataset"
-        and (not app.state.last_Sha or app.state.last_Sha != payload.repo.headSha)
+        payload.event.action == "update"
+        and payload.event.scope.startswith("repo.content")
+        and payload.repo.type == "dataset"
+        # webhook posts multiple requests with the same update, this addresses that
+        and (not app.state.last_Sha or app.state.last_Sha != payload.repo.headSha)
     ):
         # no-op
         logger.info("Update detected, no action taken")
         return {"processed": False}
 
     app.state.last_Sha = payload.repo.headSha
-    task_queue.add_task(chunk_dataset)
-    task_queue.add_task(embed_dataset)
+    task_queue.add_task(chunk_dataset, ds_name=payload.repo.name)
+    task_queue.add_task(embed_dataset, ds_name=CHUNKED_DS_NAME)
 
     return {"processed": True}
 
@@ -61,11 +66,14 @@ CHUNKING
 """
 
 class Chunker:
-    def __init__(self, strategy, split_seq, chunk_len):
+    def __init__(self, strategy, split_seq=".", chunk_len=512):
         self.split_seq = split_seq
         self.chunk_len = chunk_len
-        if strategy == "spacy":
-            self.split = SpacyTextSplitter().split_text
+        if strategy == "recursive":
+            self.split = RecursiveCharacterTextSplitter(
+                chunk_size=chunk_len,
+                separators=[split_seq]
+            ).split_text
         if strategy == "sequence":
             self.split = self.seq_splitter
         if strategy == "constant":
@@ -83,15 +91,15 @@ class Chunker:
 
 def chunk_generator(input_dataset, chunker):
     for i in tqdm(range(len(input_dataset))):
-        chunks = chunker.split(input_dataset[i][chunk_config.input_text_col])
+        chunks = chunker.split(input_dataset[i][INPUT_TEXT_COL])
        for chunk in chunks:
             if chunk:
-                yield {chunk_config.input_text_col: chunk}
+                yield {INPUT_TEXT_COL: chunk}
 
 
-def chunk_dataset():
+def chunk_dataset(ds_name):
     logger.info("Update detected, chunking is scheduled")
-    input_ds = load_dataset(chunk_config.input_dataset, split="+".join(chunk_config.input_splits))
+    input_ds = load_dataset(ds_name, split="+".join(INPUT_SPLITS))
     chunker = Chunker(
         strategy=chunk_config.strategy,
         split_seq=chunk_config.split_seq,
@@ -140,15 +148,15 @@ async def embed_sent(sentence, semaphore, tmp_file):
         result = await resp.json()
 
         tmp_file.write(
-            json.dumps({"vector": result[0], chunk_config.input_text_col: sentence}) + "\n"
+            json.dumps({"vector": result[0], INPUT_TEXT_COL: sentence}) + "\n"
         )
 
 
 async def embed(input_ds, temp_file):
     semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
     jobs = [
-        asyncio.create_task(embed_sent(row[chunk_config.input_text_col], semaphore, temp_file))
-        for row in input_ds if row[chunk_config.input_text_col].strip()
+        asyncio.create_task(embed_sent(row[INPUT_TEXT_COL], semaphore, temp_file))
+        for row in input_ds if row[INPUT_TEXT_COL].strip()
     ]
     logger.info(f"num chunks to embed: {len(jobs)}")
 
@@ -158,6 +166,7 @@ async def embed(input_ds, temp_file):
 
 
 def wake_up_endpoint(url):
+    logger.info("Starting up TEI endpoint")
     n_loop = 0
     while requests.get(
         url=url,
@@ -170,10 +179,10 @@ def wake_up_endpoint(url):
     logger.info("TEI endpoint is up")
 
 
-def embed_dataset():
+def embed_dataset(ds_name):
     logger.info("Update detected, embedding is scheduled")
     wake_up_endpoint(TEI_URL)
-    input_ds = load_dataset(embed_config.input_dataset, split="+".join(chunk_config.input_splits))
+    input_ds = load_dataset(ds_name, split="+".join(INPUT_SPLITS))
     with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
         asyncio.run(embed(input_ds, temp_file))
 
src/models.py CHANGED
@@ -5,21 +5,13 @@ from typing import Literal
 
 
 class ChunkConfig(BaseModel):
-    input_dataset: str
-    input_splits: list[str]
-    input_text_col: str
-    output_dataset: str
-    strategy: Literal["spacy", "sequence", "constant"]
-    split_seq: str | list[str]
+    strategy: Literal["recursive", "sequence", "constant"]
+    split_seq: str
     chunk_len: int
     private: bool
 
 
 class EmbedConfig(BaseModel):
-    input_dataset: str
-    input_splits: list[str]
-    input_text_col: str
-    output_dataset: str
     private: bool
     semaphore_bound: int
 
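
With the dataset-related fields dropped from the models, the trimmed JSON configs still validate. A minimal sketch, assuming the configs are read with plain json.load (the loading code is not part of this diff); pydantic's default settings ignore extra keys such as input_splits that remain in the JSON files:

import json

from src.models import ChunkConfig, EmbedConfig

with open("chunk_config.json") as f:
    chunk_config = ChunkConfig(**json.load(f))
# -> strategy='recursive', split_seq='\n\n', chunk_len=512, private=False

with open("embed_config.json") as f:
    embed_config = EmbedConfig(**json.load(f))
# -> private=False, semaphore_bound=5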