File size: 5,330 Bytes
d62eaf7
 
 
 
 
8c41dd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d62eaf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c41dd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d62eaf7
 
8c41dd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0800e8
d62eaf7
 
 
 
 
 
 
 
 
 
 
 
 
 
e0800e8
 
 
d62eaf7
 
 
 
 
 
e0800e8
d62eaf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c41dd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import plotly.graph_objs as go
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import plotly.express as px
import numpy as np
import os
import pprint
import codecs
import chardet
import gradio as gr
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory
from EdgeGPT import Chatbot


def get_content(input_file):
    """Read a text file of unknown encoding and return its contents as ``str``.

    The file is read once as bytes. The encoding is guessed with ``chardet``,
    and the same bytes are decoded in memory.

    Args:
        input_file: Path of the file to read.

    Returns:
        The decoded text. Line endings are returned exactly as stored in the
        file (no newline translation is applied).

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the bytes are not valid in the guessed encoding.
    """
    with open(input_file, "rb") as f:
        raw_data = f.read()

    # chardet reports None when it cannot guess (e.g. empty input). Fall back
    # to UTF-8 instead of the platform-dependent locale encoding.
    encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
    return raw_data.decode(encoding)


def split_text(input_file, chunk_size=1000, chunk_overlap=0):
    """Split a text file into chunks and wrap each chunk as a document.

    Args:
        input_file: Path of the text file to split.
        chunk_size: Maximum chunk length, measured with ``len``.
        chunk_overlap: Number of characters shared by consecutive chunks.

    Returns:
        A ``(texts, metadatas, docs)`` tuple:
        - ``texts``: the chunk strings.
        - ``metadatas``: one ``{"source": "<stem>[<i>]"}`` dict per chunk,
          where ``<stem>`` is the file name without directory or extension.
        - ``docs``: the documents built by ``create_documents`` from those
          chunks and metadatas.
    """
    stem = os.path.splitext(os.path.basename(input_file))[0]
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )

    chunks = splitter.split_text(text=get_content(input_file=input_file))
    metadatas = [{"source": f"{stem}[{idx}]"} for idx, _chunk in enumerate(chunks)]
    # NOTE(review): create_documents presumably re-runs the splitter on each
    # chunk. The chunks already fit chunk_size, so this is expected to be a
    # no-op. Confirm before relying on it.
    documents = splitter.create_documents(texts=chunks, metadatas=metadatas)

    return chunks, metadatas, documents


def create_docs(input_file, chunk_size=1000, chunk_overlap=0):
    """Load a text file and split it into documents sharing one source tag.

    Unlike ``split_text``, the whole file is passed as a single text with a
    single metadata dict ``{"source": <stem>}``. ``<stem>`` is the file name
    without directory or extension.

    Args:
        input_file: Path of the text file to load.
        chunk_size: Maximum chunk length, measured with ``len``.
        chunk_overlap: Number of characters shared by consecutive chunks.

    Returns:
        The list produced by the splitter's ``create_documents``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )

    source = os.path.splitext(os.path.basename(input_file))[0]
    raw_text = get_content(input_file=input_file)
    return splitter.create_documents(texts=[raw_text], metadatas=[{"source": source}])


def get_similar_docs(query, index, k=5):
    """Return the ``k`` documents in *index* most similar to *query*.

    Args:
        query: Search text.
        index: Vector store exposing ``similarity_search(query=..., k=...)``.
        k: Number of matches to request.

    Returns:
        A list of ``(page_content, metadata)`` tuples, in the order returned
        by ``index.similarity_search``.
    """
    similar_docs = index.similarity_search(query=query, k=k)
    # LangChain documents store their text in ``page_content``. There is no
    # ``summary`` attribute, so reading one raised AttributeError.
    return [(doc.page_content, doc.metadata) for doc in similar_docs]


def convert_to_html(similar_docs):
    """Render ``(page_content, metadata)`` pairs as an HTML table.

    Both the page content and ``metadata["source"]`` are HTML-escaped. This
    matters because document text is arbitrary and the result is shown in an
    HTML widget. A metadata dict without ``"source"`` yields an empty cell.

    Args:
        similar_docs: Iterable of ``(page_content, metadata)`` tuples, as
            returned by ``get_similar_docs``.

    Returns:
        A ``<table>`` string with one newline-separated ``<tr>`` per pair.
    """
    from html import escape

    rows = [
        "<tr><td>{}</td><td>{}</td></tr>".format(
            escape(str(content)), escape(str(metadata.get("source", "")))
        )
        for content, metadata in similar_docs
    ]
    header = (
        "<table><thead><tr><th>Page Content</th><th>Source</th></tr></thead><tbody>"
    )
    return header + "\n".join(rows) + "</tbody></table>"


def create_similarity_plot(embeddings, labels, query, n_clusters=3):
    """Show a 3-D PCA scatter of labelled embeddings, coloured by k-means cluster.

    Args:
        embeddings: Sequence of equal-length embedding vectors. Only the first
            ``len(labels)`` rows are used; any extra rows are ignored.
        labels: Hover text, one per kept row. The LAST kept row is also drawn
            as the highlighted query marker, so callers should put the query
            embedding (and its label) at the end.
        query: Query string, used in the legend entry of the query marker.
        n_clusters: Requested number of k-means clusters. It is capped at the
            number of points, so small result sets do not make KMeans raise.

    Returns:
        The plotly ``Figure``, which is also displayed with ``fig.show()``.

    Raises:
        ValueError: If there is no labelled embedding to plot.
    """
    # Keep only rows that have a matching label.
    points = np.asarray(embeddings, dtype=float)[: len(labels)]
    n_points = points.shape[0]
    if n_points == 0:
        raise ValueError("create_similarity_plot needs at least one labelled embedding")

    # PCA cannot produce more components than min(n_samples, n_features).
    # Pad any missing axes with zeros so the plot is always 3-D.
    n_components = min(3, n_points, points.shape[1])
    coords = PCA(n_components=n_components).fit_transform(points)
    if n_components < 3:
        coords = np.hstack([coords, np.zeros((n_points, 3 - n_components))])

    # KMeans raises if asked for more clusters than samples. n_init is pinned
    # because its default differs between scikit-learn versions.
    kmeans = KMeans(n_clusters=min(n_clusters, n_points), n_init=10)
    cluster_ids = kmeans.fit_predict(points)

    # The last kept row is highlighted as the query.
    query_trace = go.Scatter3d(
        x=[coords[-1, 0]],
        y=[coords[-1, 1]],
        z=[coords[-1, 2]],
        mode='markers',
        marker=dict(
            color='black',
            symbol='diamond',
            size=10
        ),
        name=f"Query: '{query}'"
    )

    points_trace = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers',
        marker=dict(
            color=cluster_ids,
            colorscale=px.colors.qualitative.Alphabet,
            size=5
        ),
        text=labels,
        name='Points'
    )

    fig = go.Figure(data=[query_trace, points_trace])
    fig.update_layout(
        title="3D Similarity Plot",
        legend_title_text="Cluster"
    )

    fig.show()
    return fig


def plot_similarities(query, index, embeddings=None, k=5):
    """Embed *query* and its top-*k* matches from *index* and plot them in 3-D.

    Args:
        query: Search text.
        index: Vector store exposing ``similarity_search(query=..., k=...)``.
        embeddings: Object providing ``embed_query`` and ``embed_documents``.
            When omitted, a ``HuggingFaceEmbeddings()`` instance is created for
            this call. Pass one explicitly to reuse a loaded model across calls.
        k: Number of matches to retrieve and plot.

    Returns:
        Whatever ``create_similarity_plot`` returns.
    """
    # Build the default lazily. A default-argument expression would be
    # evaluated, and the model loaded, once at import time.
    if embeddings is None:
        embeddings = HuggingFaceEmbeddings()

    similar_docs = get_similar_docs(query=query, index=index, k=k)
    texts = [text for text, _metadata in similar_docs]

    # create_similarity_plot highlights the LAST labelled row as the query,
    # so append the real query embedding (and its label) after the documents.
    vectors = list(embeddings.embed_documents(texts=texts))
    vectors.append(embeddings.embed_query(text=query))

    return create_similarity_plot(
        embeddings=vectors,
        labels=texts + [query],
        query=query,
        n_clusters=3,
    )


def start_ui(index):
    """Launch a Gradio web UI that runs similarity search against *index*.

    Each submitted query goes to ``get_similar_docs`` with its default ``k``.
    The matches are rendered as an HTML table by ``convert_to_html``.

    Args:
        index: Vector store exposing ``similarity_search(query=..., k=...)``.
    """
    def query_index(query):
        # Fetch the matches and render them for the HTML output component.
        similar_docs = get_similar_docs(query=query, index=index)
        return convert_to_html(similar_docs=similar_docs)

    # gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in 4.
    # The top-level components work in both versions.
    query_box = gr.Textbox(lines=2)
    results_html = gr.HTML()

    iface = gr.Interface(fn=query_index,
                         inputs=query_box,
                         outputs=results_html)

    # Start the Gradio server for this interface.
    iface.launch()