import asyncio
from typing import Callable, Optional, Union

import huggingface_hub
import semchunk
import tiktoken
import tokenizers
from datasets import Dataset, concatenate_datasets, load_dataset
from rich.progress import track
from transformers import PreTrainedTokenizer

TOKENIZER_OR_TOKEN_COUNTER = Union[
    str,
    tiktoken.Encoding,
    PreTrainedTokenizer,
    tokenizers.Tokenizer,
    Callable[[str], int],
]
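# `TOKENIZER_OR_TOKEN_COUNTER` covers everything this module accepts for counting
# tokens: an encoding name such as "o200k_base", a `tiktoken.Encoding`, a
# `transformers` tokenizer, a `tokenizers.Tokenizer`, or any callable that maps a
# string to its token count. The value is handed directly to `semchunk.chunkerify`.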


class SemanticChunker:
    """
    SemanticChunker is a class that chunks documents into smaller segments and
    publishes them as datasets.

    This class uses the `semchunk` library to break down large documents into
    smaller, manageable chunks based on a specified tokenizer or token counter.
    This is particularly useful for processing large text datasets where
    smaller segments are needed for analysis or other operations.

    !!! example "Example Usage"
        ```python
        from medrag_multi_modal.semantic_chunking import SemanticChunker

        chunker = SemanticChunker(chunk_size=256)
        chunker.chunk(
            document_dataset="geekyrakshit/grays-anatomy-test",
            chunk_dataset_repo_id="geekyrakshit/grays-anatomy-chunks-test",
        )
        ```

    Args:
        tokenizer_or_token_counter (TOKENIZER_OR_TOKEN_COUNTER): The tokenizer or
            token counter to be used for chunking.
        chunk_size (Optional[int]): The size of each chunk. If not specified, the
            default chunk size from `semchunk` will be used.
        max_token_chars (Optional[int]): The maximum number of characters per token.
            If not specified, the default value from `semchunk` will be used.
        memoize (bool): Whether to memoize the chunking process for efficiency.
            Defaults to True.
    """

    def __init__(
        self,
        tokenizer_or_token_counter: TOKENIZER_OR_TOKEN_COUNTER = "o200k_base",
        chunk_size: Optional[int] = None,
        max_token_chars: Optional[int] = None,
        memoize: bool = True,
    ) -> None:
        self.chunker = semchunk.chunkerify(
            tokenizer_or_token_counter,
            chunk_size=chunk_size,
            max_token_chars=max_token_chars,
            memoize=memoize,
        )

    def chunk(
        self,
        document_dataset: Union[Dataset, str],
        chunk_dataset_repo_id: Optional[str] = None,
        overwrite_dataset: bool = False,
    ) -> Dataset:
        """
        Chunks a document dataset into smaller segments and publishes them as a new dataset.

        This method takes a document dataset, either as a HuggingFace `Dataset` object or a
        string naming a dataset repository, and splits each document into smaller segments
        using the configured chunker. The resulting chunks can optionally be published to a
        HuggingFace dataset repository.

        Args:
            document_dataset (Union[Dataset, str]): The document dataset to be chunked, either
                a HuggingFace `Dataset` object or a string representing the dataset repository ID.
            chunk_dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset
                to publish the chunks to, if provided. Defaults to None.
            overwrite_dataset (bool): Whether to overwrite an existing chunk dataset instead of
                appending to it. Defaults to False.

        Returns:
            Dataset: A HuggingFace Dataset object containing the chunks.
        """
        document_dataset = (
            load_dataset(document_dataset, split="corpus")
            if isinstance(document_dataset, str)
            else document_dataset
        ).to_list()

        chunks = []

        async def process_document(idx, document):
            document_chunks = self.chunker.chunk(str(document["text"]))
            for chunk in document_chunks:
                chunk_dict = {"document_idx": idx, "text": chunk}
                for key, value in document.items():
                    if key not in chunk_dict:
                        chunk_dict[key] = value
                chunks.append(chunk_dict)

        async def process_all_documents():
            tasks = []
            for idx, document in track(
                enumerate(document_dataset),
                total=len(document_dataset),
                description="Chunking documents",
            ):
                tasks.append(process_document(idx, document))
            await asyncio.gather(*tasks)

        asyncio.run(process_all_documents())
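
        # `chunks` is appended to in document order, and Python's sort is stable,
        # so sorting by `document_idx` keeps each document's chunks in the order
        # they were produced.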
        chunks.sort(key=lambda x: x["document_idx"])

        dataset = Dataset.from_list(chunks)
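        # If the target repo already holds a "chunks" split and overwriting is not
        # requested, the newly produced chunks are concatenated with the existing
        # ones before pushing.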
        if chunk_dataset_repo_id:
            if huggingface_hub.repo_exists(chunk_dataset_repo_id, repo_type="dataset"):
                if not overwrite_dataset:
                    dataset = concatenate_datasets(
                        [
                            dataset,
                            load_dataset(chunk_dataset_repo_id, split="chunks"),
                        ]
                    )
            dataset.push_to_hub(repo_id=chunk_dataset_repo_id, split="chunks")

        return dataset
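

if __name__ == "__main__":
    # Minimal usage sketch mirroring the class docstring's example. The source
    # dataset ID below is the illustrative one from that example and is assumed
    # to expose a "corpus" split with a "text" column; pass `chunk_dataset_repo_id`
    # as well to publish the resulting chunks to the Hub.
    chunker = SemanticChunker(chunk_size=256)
    chunk_dataset = chunker.chunk(document_dataset="geekyrakshit/grays-anatomy-test")
    print(chunk_dataset)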