Ransaka committed
Commit d06496c
Parent: 4d31406

Added files

.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
app.py ADDED
@@ -0,0 +1,92 @@
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ import altair as alt
+ from PIL import Image
+ from embeddings.embeddings import load_model
+ from sentence_transformers import util
+
+ # Create sample data
+ data = pd.DataFrame({
+     'Name': ['Alice', 'Bob', 'Charlie', 'David'],
+     'Age': [25, 30, 22, 35]
+ })
+
+ # Sample PNG file
+ image = Image.open(r'plots\clusters.png')
+
+ # Data for the interactive chart
+ chart_data = pd.read_csv(r"data\top_cluster_dataset.csv", dtype={'Headline': str, 'x': np.float64, 'y': np.float64, 'labels': str})
+
+ # Create a Streamlit app
+ st.set_page_config(page_title="Sample Webpage", page_icon=":bar_chart:")
+
+ # Define tabs
+ tabs = ["Search", "Clustering Results"]
+ selected_tab = st.sidebar.radio("Select a Tab", tabs)
+
+ # Main content
+ if selected_tab == "Search":
+     sample_sentences = chart_data['Headline'].sample(10, random_state=1).tolist()
+     st.title("Calculate Sentences Similarity")
+     # Dropdown to select the model to use
+     st.subheader("Select a model to use")
+     model_list = ["Ransaka/SinhalaRoberta", "keshan/SinhalaBERTo"]
+     selected_model = st.selectbox("Select Model", model_list)
+     model = load_model(selected_model)
+
+     sentence1 = st.text_input("Enter Sentence 1", "")
+     sentence2 = st.text_input("Enter Sentence 2", "")
+
+     if sentence1 and sentence2:
+         # Button to calculate similarity
+         if st.button("Calculate Similarity"):
+             with st.spinner('Calculating Similarity...'):
+                 # Calculate similarity
+                 similarity = util.pytorch_cos_sim(model.encode(sentence1), model.encode(sentence2))[0][0]
+                 if similarity > 0.7:
+                     st.success(f"Sentences are similar (Score: {similarity:.3f})")
+                 elif similarity > 0.5:
+                     st.warning(f"Sentences are somewhat similar (Score: {similarity:.3f})")
+                 else:
+                     st.error(f"Sentences are not similar (Score: {similarity:.3f})")
+     else:
+         st.write("Enter two sentences to calculate similarity, or start with the sample sentences below.")
+         # Button to randomize the sample sentences shown below
+         if st.button("Randomize Sentences"):
+             sample_sentences = chart_data['Headline'].sample(10).tolist()
+         for sentence in sample_sentences:
+             # Show the sample sentences
+             st.write(sentence)
+
+ elif selected_tab == "Clustering Results":
+     st.title("Clustering Results Tab")
+
+     # Display PNG image
+     st.subheader("Static PNG File")
+     st.image(image, use_column_width=False, caption='Static PNG File', width=750)
+
+     altair_chart = alt.Chart(chart_data).mark_circle().encode(
+         x='x',
+         y='y',
+         color='labels',
+         tooltip='Headline'
+     ).properties(
+         width=750,
+         height=500
+     ).interactive()
+     # Display chart
+     st.subheader("Interactive Chart for top clusters")
+     st.altair_chart(altair_chart, use_container_width=False, theme="streamlit")
+
+     # Dropdown to update the DataFrame
+     st.subheader("Select a cluster")
+     unique_clusters = chart_data['labels'].unique().tolist()
+     selected_value = st.selectbox("Select Value", unique_clusters)
+
+     # Filter and display results based on the selected cluster
+     if selected_value:
+         filtered_data = chart_data[chart_data['labels'].str.contains(selected_value, case=False)].sample(10)[['Headline']].reset_index(drop=True)
+         st.dataframe(filtered_data, width=750)
+     else:
+         st.write("Select a cluster to display results.")
+
clustering/clustering.py ADDED
@@ -0,0 +1,58 @@
+ import hdbscan
+ import umap
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+
+ def load_data():
+     # Load data
+     embeddings = np.load(r'data\top_cluster_embeddings.npy')
+     return embeddings
+
+ def get_clusters(embeddings):
+     # Get clusters
+     umap_embeddings = umap.UMAP(
+         n_neighbors=15,
+         n_components=15,
+         metric='cosine'
+     ).fit_transform(embeddings)
+
+     cluster = hdbscan.HDBSCAN(
+         min_cluster_size=30,
+         metric='euclidean',
+         cluster_selection_method='eom'
+     ).fit(umap_embeddings)
+
+     return cluster.labels_
+
+ def get_2d_data_for_plotting(embeddings):
+     # Get 2D data for plotting
+     umap_embeddings = umap.UMAP(
+         n_neighbors=15,
+         n_components=2,
+         metric='cosine'
+     ).fit_transform(embeddings)
+
+     return umap_embeddings
+
+ def plot_clusters(embeddings, cluster_labels):
+     umap_data = get_2d_data_for_plotting(embeddings)
+     result = pd.DataFrame(umap_data, columns=['x', 'y'])
+     result['labels'] = cluster_labels
+
+     # Visualize clusters
+     fig, ax = plt.subplots(figsize=(20, 10))
+     outliers = result.loc[result.labels == -1, :]
+     clustered = result.loc[result.labels != -1, :]
+     plt.scatter(outliers.x, outliers.y, color='#BDBDBD', s=0.05)
+     plt.scatter(clustered.x, clustered.y, c=clustered.labels, s=0.05, cmap='hsv_r')
+     plt.colorbar()
+     plt.savefig(r'plots\clusters.png', dpi=300)
+
+ def main():
+     embeddings = load_data()
+     cluster_labels = get_clusters(embeddings)
+     plot_clusters(embeddings, cluster_labels)
+
+ if __name__ == '__main__':
+     main()
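For reference, a minimal sketch of inspecting the labels returned by get_clusters() above; it is not part of this commit and assumes data\top_cluster_embeddings.npy exists at the path load_data() reads.

import numpy as np
from clustering.clustering import load_data, get_clusters

embeddings = load_data()            # reads data\top_cluster_embeddings.npy
labels = get_clusters(embeddings)   # HDBSCAN labels; -1 marks outliers, as plot_clusters() assumes
sizes = dict(zip(*np.unique(labels, return_counts=True)))
print(sizes)                        # cluster id -> number of points, including the -1 outlier bucket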
data/top_cluster_dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
embeddings/__int__.py ADDED
File without changes
embeddings/embeddings.py ADDED
@@ -0,0 +1,100 @@
+ """
+ This file contains the code for the embeddings.
+ Tested models as follows:
+ - Ransaka/SinhalaRoberta
+ - keshan/SinhalaBERTo
+ This file uses the Ransaka/SinhalaRoberta model for the embeddings.
+
+ You can download the models from huggingface.co:
+ - https://huggingface.co/Ransaka/SinhalaRoberta
+ - https://huggingface.co/keshan/SinhalaBERTo
+
+ You can download the dataset from kaggle.com:
+ - https://www.kaggle.com/datasets/ransakaravihara/hiru-news-set3
+
+ """
+ import random
+ import numpy as np
+ import pandas as pd
+
+ import torch
+ from sentence_transformers import SentenceTransformer, models, util
+
+ model_id = "Ransaka/SinhalaRoberta"
+
+ def load_and_process_data(file_path: str) -> list:
+     """
+     This function loads the data from the file path and processes it.
+     """
+     def processor(text: str) -> str:
+         """Only addresses the most common issues in the dataset"""
+         return text\
+             .replace("\u200d", "")\
+             .replace("Read More..", "")\
+             .replace("ඡායාරූප", "")\
+             .replace("\xa0", "")\
+             .replace("වීඩියෝ", "")\
+             .replace("()", "")
+
+     def basic_processing(series: pd.Series) -> pd.Series:
+         """Applies the processor function to a pandas series"""
+         return series\
+             .apply(processor)
+
+     df = pd.read_csv(file_path)
+     df.dropna(inplace=True)
+     df['Headline'] = basic_processing(df['Headline'])
+     # df['fullText'] = basic_processing(df['fullText'])
+
+     # Only headlines are used for the embeddings
+     sentences = df['Headline'].values.tolist()
+     random.shuffle(sentences)
+     return sentences
+
+ def load_model(model_id: str) -> SentenceTransformer:
+     """
+     This function loads the model from huggingface.co.
+     """
+     word_embedding_model = models.Transformer(model_id, max_seq_length=514)
+     pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
+
+     model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+     return model
+
+ def get_embeddings(model: SentenceTransformer, sentences: list) -> list:
+     """
+     This function returns the embeddings for the given sentences.
+     """
+     return model.encode(sentences)
+
+ def save_embeddings(embeddings: list, file_path: str):
+     """
+     This function saves the embeddings to the given file path.
+     """
+     np.save(file_path, embeddings)
+
+ def load_embeddings(file_path: str) -> list:
+     """
+     This function loads the embeddings from the given file path.
+     """
+     return np.load(file_path)
+
+ def get_similar(model: SentenceTransformer, embeddings: list, query: str, top_k: int = 5) -> list:
+     """
+     This function returns the top k similar sentences for the given query.
+     """
+     query_embedding = model.encode([query])[0]
+     cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
+     top_results = torch.topk(cos_scores, k=top_k)
+     return top_results
+
+ if __name__ == "__main__":
+     file_path = r"data\top_cluster_dataset.csv"
+
+     # Load and process data
+     sentences = load_and_process_data(file_path)
+     model = load_model(model_id)
+
+     # Get embeddings
+     embeddings = get_embeddings(model, sentences)
+     save_embeddings(embeddings, r"data\embeddings.npy")
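For reference, a minimal sketch of querying the saved embeddings with get_similar(); it is not part of this commit and assumes data\embeddings.npy was produced by the __main__ block above. Because get_embeddings() runs on a shuffled sentence list, the same list is needed to map the returned indices back to text.

from embeddings.embeddings import load_model, load_embeddings, get_similar, model_id

model = load_model(model_id)
embeddings = load_embeddings(r"data\embeddings.npy")
results = get_similar(model, embeddings, query="ජාතික රෝහලේ අද සිට දිනකට පී.සී.ආර් පරීක්ෂණ 200 ක්", top_k=5)
print(results.values)   # cosine similarity scores (torch.topk output)
print(results.indices)  # positions in the (shuffled) sentence list the embeddings were computed from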
plots/chart.html ADDED
The diff for this file is too large to render. See raw diff
 
plots/clusters.png ADDED
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ altair==5.1.1
+ faiss-cpu==1.7.4
+ hdbscan==0.8.1
+ numba==0.58.0
+ numpy==1.25.2
+ sentence-transformers==2.2.2
+ sentencepiece==0.1.99
+ streamlit==1.27.0
+ tokenizers==0.13.3
+ torch==2.0.1
+ transformers==4.33.2
+ umap-learn==0.5.4
search_demo.py ADDED
@@ -0,0 +1,35 @@
+ """
+ Sample results:
+ Query: ක්ෂය රෝග මර්දන ව්යාපාරයේ පී.සී.ආර්. යන්ත්ර 36 භාවිතයට ගන්නැයි ඉල්ලීමක්
+ Results:
+ - ක්ෂය රෝග මර්දන ව්යාපාරයේ පී.සී.ආර්. යන්ත්ර 36 භාවිතයට ගන්නැයි ඉල්ලීමක්
+ - ජාතික රෝහලේ අද සිට දිනකට පී.සී.ආර් පරීක්ෂණ 200 ක්
+ - පී.සී.ආර්.සාම්පල රසායනාගාරවල ගොඩගැසී ඇතැයි වෛද්ය සංගමයෙන් චෝදනා
+ - කොරෝනා සොයන්න දිනකට පී.සී.ආර්. පරීක්ෂණ, 6000 ක් කිරීමේ සැලසුම්
+
+ Query: පොළොන්නරුව මහරෝහලේ අකුරට වැඩ කිරීමේ වෘත්තීය ක්රියාමාර්ගයක්
+ Results:
+ - පොළොන්නරුව මහරෝහලේ අකුරට වැඩ කිරීමේ වෘත්තීය ක්රියාමාර්ගයක්
+ - අකුරට වැඩ කළ රේගු වෘත්තීය සමිති, වර්ජනයකට සැරසේ
+ - ජාතික සත්ත්වෝද්යාන වෘත්තීය සමිති වැඩ වර්ජනයක
+ - ජල සම්පාදන වෘත්තීය සමිති ඒකාබද්ධ සන්ධානයෙන් වෘත්තීය ක්රියාමාර්ගවලට
+
+ Query: අංගොඩ අයි ඩී එච් රෝහලේ ඩෙංගු විශේෂ ප්රතිකාර ඒකකය තවම නැහැ
+ Results:
+ - අංගොඩ අයි ඩී එච් රෝහලේ ඩෙංගු විශේෂ ප්රතිකාර ඒකකය තවම නැහැ
+ - අයි.ඩී.එච්. රෝහලෙන් පැන ගිය කොරෝනා ආසාදිත කාන්තාව සොයා තවදුරටත් මෙහෙයුම්
+ - අයි.ඩී.එච්. රෝහලෙන් පැන්න කොරෝනා ආසාදිත කාන්තාව සොයන මෙහෙයුම අඛණ්ඩව
+ - කොරෝනා වෛරසය ආසාදනය වී ඇත්දැයි සැකයෙන්, සතියක් තුල 71ක් අයි.ඩී.එච් රෝහලට
+
+ Query: කමිටු ගැන විශ්වාසයක් නැහැ - මාළඹේ පෞද්ගලික වෛද්ය විද්යාලයීය දෙමාපිය සංසදය
+ Results:
+ - කමිටු ගැන විශ්වාසයක් නැහැ - මාළඹේ පෞද්ගලික වෛද්ය විද්යාලයීය දෙමාපිය සංසදය
+ - මාළඹේ වෛද්ය විද්යාලයීය දෙමාපිය සංසදය ජනපති ලේකම් කාර්යාලයට
+ - සයිටම් ගැටළුව වෙනතකට යොමුකිරීමට ආණ්ඩුව උපක්රම යොදනවා - වෛද්ය පීඨ ශිෂ්ය ක්රියාකාරී කමිටුව
+ - එකම විසඳුම සයිටම් අහෝසි කිරීමයි - වෛද්ය පීඨ ශිෂ්ය ක්රියාකාරී කමිටුව
+ """
+
+ from vector_search.vector_search import search_demo
+
+ if __name__ == "__main__":
+     search_demo(top_k=4)
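search_demo() also accepts a custom query list; a minimal sketch (not part of this commit), reusing one of the sample queries above:

from vector_search.vector_search import search_demo

search_demo(
    test_queries=["පොළොන්නරුව මහරෝහලේ අකුරට වැඩ කිරීමේ වෘත්තීය ක්රියාමාර්ගයක්"],  # any list of Sinhala headline-style queries
    top_k=3,
)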
vector_search/__init__.py ADDED
File without changes
vector_search/vector_search.py ADDED
@@ -0,0 +1,102 @@
+ """
+ This file is used to search the most similar vectors in the database using the faiss library.
+ Uses the Indexer class from the daily-llama repo (https://github.com/Ransaka/daily-llama).
+ """
+ import numpy as np
+ import pandas as pd
+ from embeddings.embeddings import load_model, model_id
+
+ # from the daily-llama repo
+ import faiss
+
+ class Indexer:
+     def __init__(self, embed_vec):
+         self.embeddings_vec = embed_vec
+         self.build_index()
+
+     def build_index(self):
+         """
+         Build the index for the embeddings.
+
+         This function initializes the index for the embeddings. It calculates the dimension (self.d)
+         of the embeddings vector and creates an IndexFlatL2 object (self.index) for the given dimension.
+         It then adds the embeddings vector (self.embeddings_vec) to the index.
+
+         Parameters:
+         - None
+
+         Return:
+         - None
+         """
+         self.d = self.embeddings_vec.shape[1]
+         self.index = faiss.IndexFlatL2(self.d)
+         self.index.add(self.embeddings_vec)
+
+     def topk(self, vector, k=4):
+         """
+         A function that takes in a vector and an optional parameter k and returns the indices of the k nearest neighbors in the index.
+
+         Parameters:
+         vector: A numpy array representing the input vector.
+         k (optional): An integer representing the number of nearest neighbors to retrieve. Defaults to 4 if not specified.
+
+         Returns:
+         I: A numpy array containing the indices of the k nearest neighbors in the index.
+         """
+         # vec = self.retreaver.encode(text)['embeddings'].detach().cpu().numpy()
+         _, I = self.index.search(vector, k)
+         return I
+
+
+ def get_embeddings_vec(file_path):
+     """
+     This function loads the embeddings from the given file path.
+
+     Parameters:
+     - file_path: A string representing the path to the embeddings file.
+
+     Return:
+     - embeddings_vec: A numpy array containing the embeddings.
+     """
+     return np.load(file_path)
+
+ def get_similar(indexer, text_embeddings, top_k=5):
+     """
+     This function returns the top k similar sentences for the given query.
+
+     Parameters:
+     - indexer: An Indexer object representing the indexer for the embeddings.
+     - text_embeddings: A np.array representing the query embeddings.
+     - top_k (optional): An integer representing the number of nearest neighbors to retrieve. Defaults to 5 if not specified.
+
+     Return:
+     - top_results: A numpy array containing the indices of the k nearest neighbors in the index.
+     """
+     return indexer.topk(text_embeddings, k=top_k).flatten()
+
+ def search_demo(test_queries: list = None, top_k: int = 1):
+     """
+     Runs a search demo: prints the top k most similar headlines for each query.
+     """
+     model = load_model(model_id)
+     embeddings_vec = get_embeddings_vec(r"data\top_cluster_embeddings.npy")
+     indexer = Indexer(embeddings_vec)
+
+     cluster_dataset = pd.read_csv(r"data\top_cluster_dataset.csv", usecols=['Headline'])
+     search_space = cluster_dataset['Headline'].values.tolist()
+     if test_queries is None:
+         test_queries = [
+             "ක්ෂය රෝග මර්දන ව්යාපාරයේ පී.සී.ආර්. යන්ත්ර 36 භාවිතයට ගන්නැයි ඉල්ලීමක්",
+             "පොළොන්නරුව මහරෝහලේ අකුරට වැඩ කිරීමේ වෘත්තීය ක්රියාමාර්ගයක්",
+             "අංගොඩ අයි ඩී එච් රෝහලේ ඩෙංගු විශේෂ ප්රතිකාර ඒකකය තවම නැහැ ",
+             "කමිටු ගැන විශ්වාසයක් නැහැ - මාළඹේ පෞද්ගලික වෛද්ය විද්යාලයීය දෙමාපිය සංසදය"
+         ]
+
+     for query in test_queries:
+         query_embeddings = model.encode(query).reshape(1, -1)
+         print("Query: ", query)
+         print("Results: ")
+         for index in get_similar(indexer, query_embeddings, top_k=top_k):
+             print("\t-", search_space[index])
+         print()
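For reference, a minimal sketch of using Indexer directly instead of search_demo(); it is not part of this commit and assumes the same data files as above. Note that IndexFlatL2 ranks by L2 distance over the raw embeddings, unlike the cosine similarity used in embeddings.get_similar().

import pandas as pd
from embeddings.embeddings import load_model, model_id
from vector_search.vector_search import Indexer, get_embeddings_vec, get_similar

model = load_model(model_id)
indexer = Indexer(get_embeddings_vec(r"data\top_cluster_embeddings.npy"))
headlines = pd.read_csv(r"data\top_cluster_dataset.csv", usecols=['Headline'])['Headline'].tolist()

# Encode a query (taken from the sample results above) and print its nearest headlines.
query_vec = model.encode("කොරෝනා සොයන්න දිනකට පී.සී.ආර්. පරීක්ෂණ").reshape(1, -1)
for idx in get_similar(indexer, query_vec, top_k=3):
    print("-", headlines[idx])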