|
import argparse |
|
from dotenv import load_dotenv |
|
import asyncio |
|
import gradio as gr |
|
import numpy as np |
|
import time |
|
import json |
|
import os |
|
import tempfile |
|
import requests |
|
import logging |
|
|
|
from aiohttp import ClientSession |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from datasets import Dataset, load_dataset |
|
from tqdm import tqdm |
|
from tqdm.asyncio import tqdm_asyncio |
|
|
|
# Load variables from a local .env file into the process environment before
# any of the settings below are read.
load_dotenv()

# Credentials handed to the Gradio login prompt when the app is launched.
USERNAME = os.environ.get("USERNAME")
PWD = os.environ.get("USER_PWD")

# Token sent as a Bearer header to the TEI endpoint and used for Hub access.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Upper bound on concurrent embedding requests. Kept as a string here; it is
# converted with int() at the point where the semaphore is created.
SEMAPHORE_BOUND = os.environ.get("SEMAPHORE_BOUND", "5")

# Module-wide logger; INFO level so progress messages are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
class Chunker:
    """Split text into chunks using one of three strategies.

    Strategies:
        * ``"recursive"`` - delegate to ``RecursiveCharacterTextSplitter`` with
          ``chunk_len`` as the chunk size and ``split_seq`` as its only
          separator.
        * ``"sequence"`` - split on every occurrence of ``split_seq``.
        * ``"constant"`` - cut the text into consecutive slices of
          ``chunk_len`` characters (the last slice may be shorter).

    After construction, ``self.split`` is the callable to use: it takes a
    string and returns the chunks.
    """

    def __init__(self, strategy, split_seq=".", chunk_len=512):
        """Select the splitting callable for ``strategy``.

        Args:
            strategy: One of ``"recursive"``, ``"sequence"`` or ``"constant"``.
            split_seq: Separator used by the recursive and sequence strategies.
            chunk_len: Chunk size in characters for the recursive and constant
                strategies. May be an int or a numeric string, because UI text
                boxes deliver strings. It is only validated for the strategies
                that actually use it.

        Raises:
            ValueError: If ``strategy`` is unknown, if ``chunk_len`` is not a
                positive integer for a strategy that needs it, or if
                ``split_seq`` is empty for the sequence strategy.
        """
        self.split_seq = split_seq
        self.chunk_len = chunk_len

        if strategy == "recursive":
            self.chunk_len = self._parse_chunk_len(chunk_len)
            self.split = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_len,
                separators=[split_seq],
            ).split_text
        elif strategy == "sequence":
            # str.split("") raises at split time; fail early with a clear message.
            if not split_seq:
                raise ValueError("split_seq must be a non-empty string for the 'sequence' strategy")
            self.split = self.seq_splitter
        elif strategy == "constant":
            self.chunk_len = self._parse_chunk_len(chunk_len)
            self.split = self.const_splitter
        else:
            # Without this, self.split would be undefined and only fail later
            # with an AttributeError far away from the actual mistake.
            raise ValueError(f"Unknown chunking strategy: {strategy!r}")

    @staticmethod
    def _parse_chunk_len(chunk_len):
        """Return ``chunk_len`` as a positive int or raise ``ValueError``."""
        try:
            size = int(chunk_len)
        except (TypeError, ValueError):
            raise ValueError(
                f"chunk_len must be a positive integer, got {chunk_len!r}"
            ) from None
        if size <= 0:
            raise ValueError(f"chunk_len must be a positive integer, got {chunk_len!r}")
        return size

    def seq_splitter(self, text):
        """Split ``text`` on every occurrence of ``self.split_seq``."""
        return text.split(self.split_seq)

    def const_splitter(self, text):
        """Cut ``text`` into consecutive slices of ``self.chunk_len`` characters."""
        size = self.chunk_len
        return [text[start:start + size] for start in range(0, len(text), size)]
|
|
|
|
|
def generator(input_ds, input_text_col, chunker):
    """Yield one single-column record per non-empty chunk of every input row.

    Args:
        input_ds: Iterable of row mappings (e.g. a loaded dataset); each row is
            indexed by ``input_text_col``.
        input_text_col: Name of the column holding the text to chunk. The same
            name is used as the key of every yielded record.
        chunker: Object whose ``split(text)`` returns an iterable of chunks.

    Yields:
        ``{input_text_col: chunk}`` for every truthy chunk.
    """
    # Iterate rows directly instead of indexing by position; tqdm still picks
    # up the total from len(input_ds) when the dataset is sized.
    for row in tqdm(input_ds):
        text = row[input_text_col]
        if text is None:
            # A missing text value has nothing to chunk and would crash split().
            continue
        for chunk in chunker.split(text):
            if chunk:
                yield {input_text_col: chunk}
|
|
|
|
|
def _tei_headers():
    """Build the JSON / Bearer-token headers sent with every TEI request."""
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {HF_TOKEN}",
    }


async def _request_embedding(session, tei_url, payload):
    """POST ``payload`` to ``tei_url`` and return the decoded JSON response.

    Raises:
        RuntimeError: If the endpoint answers with a non-200 status; the
            message is the response body.
    """
    async with session.post(tei_url, json=payload) as resp:
        if resp.status != 200:
            raise RuntimeError(await resp.text())
        return await resp.json()


async def embed_sent(sentence, embed_in_text_col, semaphore, tei_url, tmp_file, session=None):
    """Embed one sentence and append the result to ``tmp_file`` as a JSON line.

    Args:
        sentence: Text to embed.
        embed_in_text_col: Key under which ``sentence`` is stored in the
            written record (next to ``"vector"``).
        semaphore: Async context manager bounding concurrent requests.
        tei_url: URL of the embedding endpoint.
        tmp_file: Writable text file object receiving one JSON line.
        session: Optional already-open ``ClientSession`` to reuse. When
            omitted, a private session is opened for this single request.

    Raises:
        RuntimeError: If the endpoint answers with a non-200 status.
    """
    payload = {"inputs": sentence, "truncate": True}

    async with semaphore:
        if session is None:
            async with ClientSession(headers=_tei_headers()) as own_session:
                result = await _request_embedding(own_session, tei_url, payload)
        else:
            result = await _request_embedding(session, tei_url, payload)

        # The first element of the response is stored as the vector of the
        # single input that was sent.
        tmp_file.write(
            json.dumps({"vector": result[0], embed_in_text_col: sentence}) + "\n"
        )


async def embed_ds(input_ds, tei_url, embed_in_text_col, temp_file):
    """Embed every non-blank text of ``input_ds`` and write JSON lines to ``temp_file``.

    Args:
        input_ds: Iterable of row mappings indexed by ``embed_in_text_col``.
        tei_url: URL of the embedding endpoint.
        embed_in_text_col: Name of the text column to embed.
        temp_file: Writable text file object receiving one JSON line per
            embedded text. It is not flushed or closed here.

    Raises:
        RuntimeError: Propagated from a failed embedding request.
    """
    semaphore = asyncio.BoundedSemaphore(int(SEMAPHORE_BOUND))

    # One session for the whole run so connections are pooled instead of
    # opening a fresh session (and connection) for every sentence.
    async with ClientSession(headers=_tei_headers()) as session:
        texts = (row[embed_in_text_col] for row in input_ds)
        jobs = [
            asyncio.create_task(
                embed_sent(text, embed_in_text_col, semaphore, tei_url, temp_file, session)
            )
            for text in texts
            if text and text.strip()
        ]
        logger.info("num chunks to embed: %s", len(jobs))

        tic = time.perf_counter()
        try:
            await tqdm_asyncio.gather(*jobs)
        finally:
            # If one request failed, stop the rest before the shared session
            # is closed; cancel() is a no-op for tasks that already finished.
            for job in jobs:
                job.cancel()
        logger.info("embed time: %s", time.perf_counter() - tic)
|
|
|
|
|
def wake_up_endpoint(url, max_retries=40, retry_delay=2, timeout=10):
    """Poll the TEI endpoint until it answers with HTTP 200.

    Args:
        url: Endpoint URL to poll with an authenticated GET.
        max_retries: Number of retries after the first attempt.
        retry_delay: Seconds to sleep between attempts.
        timeout: Per-request timeout in seconds, so that a hanging connection
            cannot block the polling loop forever.

    Raises:
        gr.Error: If the endpoint did not answer with 200 within the allowed
            number of attempts.
    """
    logger.info("Starting up TEI endpoint")
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}

    for attempt in range(max_retries + 1):
        try:
            response = requests.get(url=url, headers=headers, timeout=timeout)
            ready = response.status_code == 200
        except (requests.ConnectionError, requests.Timeout):
            # A timeout or dropped connection counts as "not ready yet";
            # other request errors (e.g. a malformed URL) still propagate.
            ready = False

        if ready:
            logger.info("TEI endpoint is up")
            return

        if attempt < max_retries:
            time.sleep(retry_delay)

    raise gr.Error("TEI endpoint is unavailable")
|
|
|
|
|
def chunk_embed(input_ds, input_splits, input_text_col, chunk_out_ds,
                strategy, split_seq, chunk_len, embed_out_ds, tei_url, private):
    """Chunk a dataset, push the chunks, embed them and push the embeddings.

    Args:
        input_ds: Name of the dataset to load.
        input_splits: Comma-separated split names; they are joined with ``+``
            into a single split expression.
        input_text_col: Name of the text column to chunk and embed.
        chunk_out_ds: Repository name for the chunked dataset.
        strategy: Chunking strategy passed to ``Chunker``.
        split_seq: Separator passed to ``Chunker``.
        chunk_len: Chunk length passed to ``Chunker``.
        embed_out_ds: Repository name for the embedded dataset.
        tei_url: URL of the embedding endpoint.
        private: Whether the pushed datasets are private.

    Raises:
        gr.Error: On any failure, carrying the underlying error message.
    """
    gr.Info("Started chunking")
    try:
        # Drop empty and whitespace-only entries such as a trailing comma.
        split_names = [spl.strip() for spl in input_splits.split(",") if spl.strip()]
        input_ds = load_dataset(input_ds, split="+".join(split_names), token=HF_TOKEN)
        chunker = Chunker(strategy, split_seq, chunk_len)
    except Exception as e:
        raise gr.Error(str(e)) from e

    try:
        gen_kwargs = {
            "input_ds": input_ds,
            "input_text_col": input_text_col,
            "chunker": chunker,
        }
        chunked_ds = Dataset.from_generator(generator, gen_kwargs=gen_kwargs)
        chunked_ds.push_to_hub(
            chunk_out_ds,
            private=private,
            token=HF_TOKEN,
        )
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e

    gr.Info("Done chunking")
    logger.info("Done chunking")

    try:
        wake_up_endpoint(tei_url)
        with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
            asyncio.run(embed_ds(chunked_ds, tei_url, input_text_col, temp_file))

            # The file is read back by name while this handle is still open;
            # flush so buffered trailing records are actually on disk.
            temp_file.flush()

            embedded_ds = Dataset.from_json(temp_file.name)
            embedded_ds.push_to_hub(
                embed_out_ds,
                private=private,
                token=HF_TOKEN,
            )
    except gr.Error:
        # Already a user-facing error (e.g. endpoint unavailable); keep it as is.
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e

    gr.Info("Done embedding")
    logger.info("Done embedding")
|
|
|
|
|
def change_dropdown(choice):
    """Return visibility updates for the two strategy-specific text boxes.

    The first returned ``gr.Textbox`` is visible for ``"recursive"`` and
    ``"sequence"``; the second is visible for every choice except
    ``"sequence"``.
    """
    show_first = choice in ("recursive", "sequence")
    show_second = choice != "sequence"
    return [
        gr.Textbox(visible=show_first),
        gr.Textbox(visible=show_second),
    ]
|
|
|
|
|
def main(args):
    """Build the Gradio UI and serve it.

    Args:
        args: Parsed command-line arguments; ``args.port`` is the port the
            app listens on.
    """
    # NOTE(review): the nesting of components inside the rows below is
    # reconstructed from the component order - confirm the intended layout.
    with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
        gr.Markdown(
            """
            ## Chunk and embed
            """
        )
        input_ds = gr.Textbox(lines=1, label="Input dataset name")
        with gr.Row():
            input_splits = gr.Textbox(lines=1, label="Input dataset splits", placeholder="train, test")
            input_text_col = gr.Textbox(lines=1, label="Input text column name", placeholder="text")
        chunk_out_ds = gr.Textbox(lines=1, label="Chunked dataset name")
        with gr.Row():
            dropdown = gr.Dropdown(
                ["recursive", "sequence", "constant"], label="Chunking strategy",
                info="'recursive' uses a Langchain recursive tokenizer, 'sequence' splits texts by a chosen sequence, "
                     "'constant' makes chunks of the constant size",
                scale=2
            )
            split_seq = gr.Textbox(
                lines=1,
                interactive=True,
                visible=False,
                label="Sequence",
                info="A text sequence to split on",
                placeholder="\n\n"
            )
            chunk_len = gr.Textbox(
                lines=1,
                interactive=True,
                visible=False,
                label="Length",
                info="The length of chunks to split into in characters",
                placeholder="512"
            )
        # Show only the inputs that are relevant for the selected strategy.
        dropdown.change(fn=change_dropdown, inputs=dropdown, outputs=[split_seq, chunk_len])
        embed_out_ds = gr.Textbox(lines=1, label="Embedded dataset name")
        private = gr.Checkbox(label="Make output datasets private")
        tei_url = gr.Textbox(lines=1, label="TEI endpoint url")
        with gr.Row():
            gr.ClearButton(
                components=[input_ds, input_splits, input_text_col, chunk_out_ds,
                            dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
            )
            embed_btn = gr.Button("Submit")
            # Input order must match the positional parameters of chunk_embed.
            embed_btn.click(
                fn=chunk_embed,
                inputs=[input_ds, input_splits, input_text_col, chunk_out_ds,
                        dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
            )

    demo.queue()
    demo.launch(auth=(USERNAME, PWD), server_name="0.0.0.0", server_port=args.port)
|
|
|
if __name__ == "__main__":
    # Command-line entry point: only the listening port is configurable.
    cli = argparse.ArgumentParser(description="A MAGIC example by ConceptaTech")
    cli.add_argument("--port", type=int, default=7860, help="Port to expose Gradio app")
    main(cli.parse_args())
|
|