ksvmuralidhar commited on
Commit
e97cba6
·
verified ·
1 Parent(s): 85a8da7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -13
app.py CHANGED
@@ -1,5 +1,7 @@
1
  from sentence_transformers import SentenceTransformer
 
2
  import os
 
3
  from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
4
  import streamlit as st
5
 
@@ -7,8 +9,9 @@ import streamlit as st
7
  @st.cache_resource
8
  def load_sentence_transformer():
9
  sent_model = SentenceTransformer('all-mpnet-base-v2')
 
10
  print('loaded sentence transformer')
11
- return sent_model
12
 
13
 
14
  class TextVectorizer:
@@ -39,20 +42,23 @@ def find_similar_news(text: str, top_n: int=5):
39
  param=search_params,
40
  limit=top_n,
41
  guarantee_timestamp=1,
42
- output_fields=['article_desc', 'article_category']) # which fields to return in output
43
 
44
 
45
- output_dict = {"input_text": text, "similar_texts": [hit.entity.get('article_desc') for hits in result for hit in hits],
46
- "text_category": [hit.entity.get('article_category') for hits in result for hit in hits]}
47
- txt_category = [f'<li><b>{txt}</b> (<i>{cat}</i>)</li>' for txt, cat in zip(output_dict.get('similar_texts'), output_dict.get('text_category'))]
48
- similar_txt = ''.join(txt_category)
 
 
 
 
49
  return f"<h4>Similar News Articles</h4><ol>{similar_txt}</ol>"
50
- # return output_dict
51
 
52
 
53
  vectorizer = TextVectorizer()
54
  collection = get_milvus_collection()
55
- sent_model = load_sentence_transformer()
56
 
57
 
58
  def main():
@@ -62,18 +68,17 @@ def main():
62
  desc = '''<p style="font-size: 13px;">
63
  Embeddings of 300,000 news headlines are stored in Milvus vector database, used as a feature store.
64
  Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2).
65
- Similar news headlines are retrieved from the vector database using Euclidean distance as similarity metric.
66
  <span style="color: red;">This method (all-mpnet-base-v2) has the best performance compared to multi-qa-distilbert-cos-v1 fine-tuned using TSDAE
67
  and extracting embeddings from fine-tuned DistilBERT classifier.</span>
68
  </p>
69
  '''
70
  st.markdown(desc, unsafe_allow_html=True)
71
- news_txt = st.text_area("Paste the headline of a news article:", "", height=50)
72
- top_n = st.slider('Select number of similar articles to display', 1, 100, 10)
73
 
74
- if st.button("Submit"):
75
  result = find_similar_news(news_txt, top_n)
76
- # st.write(result)
77
  st.markdown(result, unsafe_allow_html=True)
78
 
79
 
 
1
  from sentence_transformers import SentenceTransformer
2
+ from sentence_transformers.cross_encoder import CrossEncoder
3
  import os
4
+ import numpy as np
5
  from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
6
  import streamlit as st
7
 
 
9
  @st.cache_resource
10
  def load_sentence_transformer():
11
  sent_model = SentenceTransformer('all-mpnet-base-v2')
12
+ ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
13
  print('loaded sentence transformer')
14
+ return sent_model, ce_model
15
 
16
 
17
  class TextVectorizer:
 
42
  param=search_params,
43
  limit=top_n,
44
  guarantee_timestamp=1,
45
+ output_fields=['article_desc']) # which fields to return in output
46
 
47
 
48
+ output_dict = {"input_text": text, "similar_texts": [hit.entity.get('article_desc') for hits in result for hit in hits]}
49
+ texts = np.array(output_dict.get('similar_texts'))
50
+ ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
51
+ similarity_idxs = [*np.argsort(ce_similarity_scores)[::-1]]
52
+ texts = texts[similarity_idxs]
53
+ ce_similarity_scores = ce_similarity_scores[similarity_idxs]
54
+ txt_similarity = [f'<li><b>{txt}</b> (<i>similarity: {np.round(sim, 5)})</i></li>' for txt, sim in zip(texts, ce_similarity_scores)]
55
+ similar_txt = ''.join(txt_similarity)
56
  return f"<h4>Similar News Articles</h4><ol>{similar_txt}</ol>"
 
57
 
58
 
59
  vectorizer = TextVectorizer()
60
  collection = get_milvus_collection()
61
+ sent_model, ce_model = load_sentence_transformer()
62
 
63
 
64
  def main():
 
68
  desc = '''<p style="font-size: 13px;">
69
  Embeddings of 300,000 news headlines are stored in Milvus vector database, used as a feature store.
70
  Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2).
71
+ Similar news headlines are retrieved from the vector database using Euclidean distance as similarity metric and are reranked using cross encoder.
72
  <span style="color: red;">This method (all-mpnet-base-v2) has the best performance compared to multi-qa-distilbert-cos-v1 fine-tuned using TSDAE
73
  and extracting embeddings from fine-tuned DistilBERT classifier.</span>
74
  </p>
75
  '''
76
  st.markdown(desc, unsafe_allow_html=True)
77
+ news_txt = st.text_area("Paste the headline of a news article and hit Ctrl+Enter:", "", height=30)
78
+ top_n = st.slider('Select the number of similar articles to display', 1, 100, 15)
79
 
80
+ if news_txt:
81
  result = find_similar_news(news_txt, top_n)
 
82
  st.markdown(result, unsafe_allow_html=True)
83
 
84