File size: 4,851 Bytes
e932fdf
c38bbc6
e932fdf
c38bbc6
0164e97
6115839
 
 
 
 
0164e97
 
 
e932fdf
c38bbc6
e932fdf
0164e97
e932fdf
 
0164e97
e932fdf
c38bbc6
6115839
e932fdf
 
 
6115839
c38bbc6
e932fdf
6115839
e932fdf
 
6115839
 
 
 
 
 
 
 
 
 
e932fdf
78ef349
 
 
 
 
 
 
 
 
 
6115839
0cfdb4e
 
 
 
78ef349
 
 
 
 
0cfdb4e
78ef349
0cfdb4e
 
 
6115839
 
0cfdb4e
6115839
 
97c72c9
 
78ef349
0ce5ee0
97c72c9
78ef349
6115839
8315f3e
0cfdb4e
 
6115839
 
 
 
0cfdb4e
 
6115839
e932fdf
 
6115839
05ac48b
6115839
05ac48b
6115839
 
78ef349
6115839
 
 
05ac48b
 
60c70dc
05ac48b
e932fdf
 
c38bbc6
cb83ba5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import fitz  # PyMuPDF for reading PDFs
import numpy as np
import pandas as pd
import logging
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, manhattan_distances
from sklearn.metrics.pairwise import linear_kernel as dot_similarity  # For dot product
import umap
import plotly.graph_objects as go

# Configure root logging once, at import time: INFO and above, each record
# prefixed with a timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Load the sentence-embedding model a single time at import; the functions
# below all reuse this one module-level instance instead of reloading it.
model = SentenceTransformer('all-MiniLM-L6-v2')
logging.info("Model loaded successfully.")

def process_pdf(pdf_path):
    """Extract the text of every page of a PDF and join it into one string.

    :param pdf_path: Path of the PDF file to open with PyMuPDF (``fitz``).
    :return: The text of all pages, in page order, separated by single spaces.
    """
    logging.info(f"Processing PDF: {pdf_path}")
    # Use the document as a context manager so the underlying file is closed
    # even if text extraction raises part-way through; the original never
    # closed it at all.
    with fitz.open(pdf_path) as doc:
        texts = [page.get_text() for page in doc]
    logging.info("PDF processed successfully.")
    return " ".join(texts)

def create_embeddings(text):
    """Split *text* into rough sentences and embed each with the global model.

    :param text: The full document text.
    :return: ``(embeddings, sentences)`` — the output of ``model.encode`` for
        the sentence list (one embedding per sentence), and the list of
        sentence strings in the same order.
    :raises ValueError: If the text contains no non-blank sentence, e.g. a
        scanned PDF without a text layer.
    """
    logging.info("Creating embeddings.")
    # A simple split on ". "; consider a more robust sentence splitter.
    # Whitespace-only fragments (the empty tail after a final ". ", gaps
    # between pages) carry no meaning but would still be embedded and plotted
    # as points, so drop them and trim the surrounding whitespace of the rest.
    sentences = [fragment.strip() for fragment in text.split(". ") if fragment.strip()]
    if not sentences:
        raise ValueError("No text could be extracted from the document to embed.")
    embeddings = model.encode(sentences)
    logging.info("Embeddings created successfully.")
    return embeddings, sentences

def calculate_distances(embeddings, query_embedding, metric):
    """Return the distance from every embedding to the query under *metric*.

    For every metric a smaller value means "closer", so callers can rank the
    results with ``np.argsort`` regardless of which metric was chosen.

    :param embeddings: Array-like of shape ``(n_sentences, dim)``.
    :param query_embedding: A single embedding of length ``dim``.
    :param metric: One of ``"cosine"``, ``"euclidean"``, ``"manhattan"``, ``"dot"``.
    :return: 1-D array with one distance per row of *embeddings*.
    :raises ValueError: If *metric* is not one of the supported names
        (previously this surfaced as an ``UnboundLocalError``).
    """
    # sklearn's pairwise helpers need a 2-D second argument: one row per query.
    query_matrix = [query_embedding]
    if metric == "cosine":
        distances = 1 - cosine_similarity(embeddings, query_matrix)
    elif metric == "euclidean":
        distances = euclidean_distances(embeddings, query_matrix)
    elif metric == "manhattan":
        distances = manhattan_distances(embeddings, query_matrix)
    elif metric == "dot":
        # Negate the similarity so that, as for the other metrics, smaller means closer.
        distances = -dot_similarity(embeddings, query_matrix)
    else:
        raise ValueError(
            f"Unknown distance metric {metric!r}; expected one of "
            "'cosine', 'euclidean', 'manhattan', 'dot'."
        )
    return distances.flatten()

def wrap_text(text, width=40):
    """
    Cut *text* into fixed-size pieces joined by HTML ``<br>`` tags, so long
    strings display on several lines inside a Plotly hover label.
    Pieces are cut at exact character offsets, not at word boundaries.
    :param text: The string to break up.
    :param width: Number of characters per piece (the last piece may be shorter).
    :return: The pieces joined with ``<br>``; an empty string stays empty.
    """
    offsets = range(0, len(text), width)
    pieces = (text[start:start + width] for start in offsets)
    return "<br>".join(pieces)

def generate_plotly_figure(query, pdf_file, metric):
    """Build an interactive UMAP scatter plot of a PDF's sentences and a query.

    The query and the PDF's sentences are embedded with the global ``model``
    and projected together to 2-D with UMAP. The five sentences closest to the
    query under *metric* are coloured green and the rest blue. The query is
    drawn as a separate red trace.

    :param query: Free-text query to embed and highlight.
    :param pdf_file: The uploaded PDF, either as a path string or as a
        file-like object whose ``.name`` is the path (what ``gr.File`` hands
        over depends on the Gradio version and its ``type`` setting).
    :param metric: Distance metric name forwarded to ``calculate_distances``.
    :return: A ``plotly.graph_objects.Figure``.
    """
    logging.info("Generating plot with Plotly.")
    query_embedding = model.encode([query])[0]
    # Older Gradio passes a tempfile wrapper exposing the path as ``.name``;
    # newer Gradio passes the path string itself. Accept both.
    pdf_path = getattr(pdf_file, "name", pdf_file)
    text = process_pdf(pdf_path)
    embeddings, sentences = create_embeddings(text)

    # Hover labels: break long sentences so the tooltip stays readable.
    sentence_labels = [wrap_text(sentence) for sentence in sentences]
    query_label = wrap_text(query)

    # Fit UMAP on sentences and query together so they share one 2-D space;
    # the query is the last row.
    all_embeddings = np.vstack([embeddings, query_embedding])
    umap_transform = umap.UMAP(n_neighbors=15, min_dist=0.0, n_components=2, random_state=42)
    umap_embeddings = umap_transform.fit_transform(all_embeddings)
    sentence_points, query_point = umap_embeddings[:-1], umap_embeddings[-1]

    distances = calculate_distances(embeddings, query_embedding, metric)
    # Indices of the 5 closest sentences; a set gives O(1) membership below.
    closest_indices = set(np.argsort(distances)[:5].tolist())
    colors = ['green' if i in closest_indices else 'blue' for i in range(len(sentences))]

    fig = go.Figure(data=go.Scatter(x=sentence_points[:, 0], y=sentence_points[:, 1], mode='markers',
                                    marker=dict(color=colors), text=sentence_labels,
                                    name='Chunks', hoverinfo='text'))
    fig.add_trace(go.Scatter(x=[query_point[0]], y=[query_point[1]], mode='markers',
                             marker=dict(color='red'), text=[query_label], name='Query', hoverinfo='text'))
    fig.update_layout(title="UMAP Projection of Sentences with Query Highlight", xaxis_title="UMAP 1", yaxis_title="UMAP 2")

    logging.info("Plotly figure created successfully.")
    return fig

def gradio_interface(pdf_file, query, metric):
    """Gradio callback: validate the form inputs and return the UMAP figure.

    :param pdf_file: Value of the ``gr.File`` input (``None`` if nothing was uploaded).
    :param query: Text of the query box.
    :param metric: Selected distance metric (``None`` if none was selected).
    :return: The Plotly figure built by ``generate_plotly_figure``.
    :raises gr.Error: For a missing PDF, empty query, or missing metric, so the
        user sees a readable message instead of an internal exception
        (previously a ``None`` metric crashed with ``TypeError`` on the log line).
    """
    if pdf_file is None:
        raise gr.Error("Please upload a PDF file.")
    if not query or not query.strip():
        raise gr.Error("Please enter a query.")
    if not metric:
        raise gr.Error("Please choose a distance metric.")
    logging.info("Gradio interface called with metric: %s", metric)
    fig = generate_plotly_figure(query, pdf_file, metric)
    logging.info("Returning Plotly figure.")
    return fig

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        # Restrict the upload picker to PDFs, which is what this app processes.
        gr.File(label="Upload a PDF", file_types=[".pdf"]),
        gr.Textbox(label="Query"),
        # Pre-select a metric: without a default the Radio can submit no
        # selection (None), which is not a usable metric name.
        gr.Radio(choices=["cosine", "euclidean", "manhattan", "dot"], value="cosine", label="Choose Distance Metric")
    ],
    outputs=gr.Plot(),
    title="Semantic Search Visualizer",
    description="""This tool allows you to upload a PDF document, input a query, and visualize the context of the document 
    as it relates to your query. It uses UMAP for dimensionality reduction and highlights the query and its closest contexts 
    within the document based on the selected distance metric. Choose from cosine, Euclidean, Manhattan, or dot product metrics 
    to explore different aspects of textual similarity.
    umap args: n_neighbors=15, min_dist=0.0,
    Green dots are the closest vectors
    """
)

# Start the web UI only when run as a script, not when this module is imported.
if __name__ == "__main__":
    iface.launch()