import os
import shutil
import time

import pdfplumber
import docx
import nltk
import gradio as gr

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter, NLTKTextSplitter

# The punkt tokenizer models are required for NLTK-based sentence splitting
nltk.download('punkt')
|
|
FILES_DIR = './files'
os.makedirs(FILES_DIR, exist_ok=True)
|
# Hugging Face Hub model IDs (full repo paths so they resolve directly)
MODELS = {
    'e5-base': "danielheinz/e5-base-sts-en-de",
    'multilingual-e5-base': "intfloat/multilingual-e5-base",
    'paraphrase-miniLM': "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    'paraphrase-mpnet': "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    'gte-large': "thenlper/gte-large",
    'gbert-base': "deepset/gbert-base"
}
|
|
class FileHandler:
    @staticmethod
    def extract_text(file_path):
        ext = os.path.splitext(file_path)[-1].lower()
        if ext == '.pdf':
            return FileHandler._extract_from_pdf(file_path)
        elif ext == '.docx':
            return FileHandler._extract_from_docx(file_path)
        elif ext == '.txt':
            return FileHandler._extract_from_txt(file_path)
        else:
            raise ValueError(f"Unsupported file type: {ext}")

    @staticmethod
    def _extract_from_pdf(file_path):
        with pdfplumber.open(file_path) as pdf:
            # extract_text() can return None for empty or image-only pages
            return ' '.join([(page.extract_text() or '') for page in pdf.pages])

    @staticmethod
    def _extract_from_docx(file_path):
        doc = docx.Document(file_path)
        return ' '.join([para.text for para in doc.paragraphs])

    @staticmethod
    def _extract_from_txt(file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
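# Illustrative usage (the file path below is hypothetical):
#   text = FileHandler.extract_text('./files/sample.pdf')
#   print(text[:200])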
|
|
|
class EmbeddingModel:
    def __init__(self, model_name, max_tokens=None):
        self.model = HuggingFaceEmbeddings(model_name=model_name)
        self.max_tokens = max_tokens

    def embed(self, texts):
        # embed_documents expects a list of strings and returns one vector per entry
        return self.model.embed_documents(texts)
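# Illustrative usage (downloads the model weights on first run):
#   embedder = EmbeddingModel(MODELS['e5-base'])
#   vectors = embedder.embed(["first chunk", "second chunk"])
#   print(len(vectors), len(vectors[0]))  # two vectors, one per chunk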
|
|
|
def process_files(model_name, split_strategy, chunk_size=500, overlap_size=50, max_tokens=None):
    # Concatenate the extracted text of every file in FILES_DIR
    text = ""
    for file in os.listdir(FILES_DIR):
        file_path = os.path.join(FILES_DIR, file)
        text += FileHandler.extract_text(file_path) + " "

    # Choose a splitter: sentence-based (NLTK) or recursive character-based
    if split_strategy == 'sentence':
        splitter = NLTKTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap_size)
    else:
        splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap_size)

    chunks = splitter.split_text(text)
    model = EmbeddingModel(MODELS[model_name], max_tokens=max_tokens)
    # Embed each chunk rather than the full concatenated text
    embeddings = model.embed(chunks)

    return embeddings, chunks
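# Illustrative usage (assumes documents are already present in ./files):
#   embeddings, chunks = process_files('e5-base', 'recursive', chunk_size=500, overlap_size=50)
#   print(f"{len(chunks)} chunks embedded")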
|
|
|
def search_embeddings(chunks, query, model_name, top_k):
    # Index the chunks in an in-memory FAISS store and return the top-k matches for the query
    model = HuggingFaceEmbeddings(model_name=MODELS[model_name])
    vectorstore = FAISS.from_texts(chunks, model)
    results = vectorstore.similarity_search(query, k=top_k)
    return [doc.page_content for doc in results]
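# Illustrative usage (query text is hypothetical):
#   hits = search_embeddings(chunks, "What does the introduction cover?", 'e5-base', top_k=3)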
|
|
|
def calculate_statistics(embeddings, start_time):
    # Number of embedded chunks and elapsed wall-clock time in seconds
    return {"chunks_embedded": len(embeddings), "time_taken": round(time.time() - start_time, 2)}
|
|
def upload_file(file, model_name, split_strategy, chunk_size, overlap_size, max_tokens, query, top_k):
    start_time = time.time()

    # gr.File may pass a tempfile wrapper (with .name) or a plain path, depending on the Gradio version
    src_path = file.name if hasattr(file, "name") else file
    shutil.copy(src_path, os.path.join(FILES_DIR, os.path.basename(src_path)))

    embeddings, chunks = process_files(model_name, split_strategy, int(chunk_size), int(overlap_size), max_tokens)

    results = search_embeddings(chunks, query, model_name, int(top_k))

    stats = calculate_statistics(embeddings, start_time)

    return {"results": results, "stats": stats}
|
|
iface = gr.Interface(
    fn=upload_file,
    inputs=[
        gr.File(label="Upload File"),
        gr.Dropdown(choices=list(MODELS.keys()), label="Embedding Model"),
        gr.Radio(choices=["sentence", "recursive"], label="Split Strategy"),
        gr.Slider(100, 1000, step=100, value=500, label="Chunk Size"),
        gr.Slider(0, 100, step=10, value=50, label="Overlap Size"),
        gr.Slider(50, 500, step=50, value=200, label="Max Tokens"),
        gr.Textbox(label="Search Query"),
        gr.Slider(1, 10, step=1, value=5, label="Top K")
    ],
    outputs="json"
)
|
|
|
iface.launch()
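# launch() serves the app locally (default http://127.0.0.1:7860);
# pass share=True to iface.launch() for a temporary public link.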
|
|
|
|