Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
from sentence_transformers import SentenceTransformer
|
|
|
2 |
import os
|
|
|
3 |
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
|
4 |
import streamlit as st
|
5 |
|
@@ -7,8 +9,9 @@ import streamlit as st
|
|
7 |
@st.cache_resource
|
8 |
def load_sentence_transformer():
|
9 |
sent_model = SentenceTransformer('all-mpnet-base-v2')
|
|
|
10 |
print('loaded sentence transformer')
|
11 |
-
return sent_model
|
12 |
|
13 |
|
14 |
class TextVectorizer:
|
@@ -39,20 +42,23 @@ def find_similar_news(text: str, top_n: int=5):
|
|
39 |
param=search_params,
|
40 |
limit=top_n,
|
41 |
guarantee_timestamp=1,
|
42 |
-
output_fields=['article_desc'
|
43 |
|
44 |
|
45 |
-
output_dict = {"input_text": text, "similar_texts": [hit.entity.get('article_desc') for hits in result for hit in hits]
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
49 |
return f"<h4>Similar News Articles</h4><ol>{similar_txt}</ol>"
|
50 |
-
# return output_dict
|
51 |
|
52 |
|
53 |
vectorizer = TextVectorizer()
|
54 |
collection = get_milvus_collection()
|
55 |
-
sent_model = load_sentence_transformer()
|
56 |
|
57 |
|
58 |
def main():
|
@@ -62,18 +68,17 @@ def main():
|
|
62 |
desc = '''<p style="font-size: 13px;">
|
63 |
Embeddings of 300,000 news headlines are stored in Milvus vector database, used as a feature store.
|
64 |
Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2).
|
65 |
-
Similar news headlines are retrieved from the vector database using Euclidean distance as similarity metric.
|
66 |
<span style="color: red;">This method (all-mpnet-base-v2) has the best performance compared to multi-qa-distilbert-cos-v1 fine-tuned using TSDAE
|
67 |
and extracting embeddings from fine-tuned DistilBERT classifier.</span>
|
68 |
</p>
|
69 |
'''
|
70 |
st.markdown(desc, unsafe_allow_html=True)
|
71 |
-
news_txt = st.text_area("Paste the headline of a news article:", "", height=
|
72 |
-
top_n = st.slider('Select number of similar articles to display', 1, 100,
|
73 |
|
74 |
-
if
|
75 |
result = find_similar_news(news_txt, top_n)
|
76 |
-
# st.write(result)
|
77 |
st.markdown(result, unsafe_allow_html=True)
|
78 |
|
79 |
|
|
|
1 |
from sentence_transformers import SentenceTransformer
|
2 |
+
from sentence_transformers.cross_encoder import CrossEncoder
|
3 |
import os
|
4 |
+
import numpy as np
|
5 |
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
|
6 |
import streamlit as st
|
7 |
|
|
|
9 |
@st.cache_resource
|
10 |
def load_sentence_transformer():
|
11 |
sent_model = SentenceTransformer('all-mpnet-base-v2')
|
12 |
+
ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
|
13 |
print('loaded sentence transformer')
|
14 |
+
return sent_model, ce_model
|
15 |
|
16 |
|
17 |
class TextVectorizer:
|
|
|
42 |
param=search_params,
|
43 |
limit=top_n,
|
44 |
guarantee_timestamp=1,
|
45 |
+
output_fields=['article_desc']) # which fields to return in output
|
46 |
|
47 |
|
48 |
+
output_dict = {"input_text": text, "similar_texts": [hit.entity.get('article_desc') for hits in result for hit in hits]}
|
49 |
+
texts = np.array(output_dict.get('similar_texts'))
|
50 |
+
ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
|
51 |
+
similarity_idxs = [*np.argsort(ce_similarity_scores)[::-1]]
|
52 |
+
texts = texts[similarity_idxs]
|
53 |
+
ce_similarity_scores = ce_similarity_scores[similarity_idxs]
|
54 |
+
txt_similarity = [f'<li><b>{txt}</b> (<i>similarity: {np.round(sim, 5)})</i></li>' for txt, sim in zip(texts, ce_similarity_scores)]
|
55 |
+
similar_txt = ''.join(txt_similarity)
|
56 |
return f"<h4>Similar News Articles</h4><ol>{similar_txt}</ol>"
|
|
|
57 |
|
58 |
|
59 |
vectorizer = TextVectorizer()
|
60 |
collection = get_milvus_collection()
|
61 |
+
sent_model, ce_model = load_sentence_transformer()
|
62 |
|
63 |
|
64 |
def main():
|
|
|
68 |
desc = '''<p style="font-size: 13px;">
|
69 |
Embeddings of 300,000 news headlines are stored in Milvus vector database, used as a feature store.
|
70 |
Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2).
|
71 |
+
Similar news headlines are retrieved from the vector database using Euclidean distance as similarity metric and are reranked using cross encoder.
|
72 |
<span style="color: red;">This method (all-mpnet-base-v2) has the best performance compared to multi-qa-distilbert-cos-v1 fine-tuned using TSDAE
|
73 |
and extracting embeddings from fine-tuned DistilBERT classifier.</span>
|
74 |
</p>
|
75 |
'''
|
76 |
st.markdown(desc, unsafe_allow_html=True)
|
77 |
+
news_txt = st.text_area("Paste the headline of a news article and hit Ctrl+Enter:", "", height=30)
|
78 |
+
top_n = st.slider('Select the number of similar articles to display', 1, 100, 15)
|
79 |
|
80 |
+
if news_txt:
|
81 |
result = find_similar_news(news_txt, top_n)
|
|
|
82 |
st.markdown(result, unsafe_allow_html=True)
|
83 |
|
84 |
|