from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
import os
import numpy as np
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
import streamlit as st


@st.cache_resource
def load_sentence_transformer():
    '''Load the bi-encoder (for embeddings) and the cross-encoder (for re-ranking) once per session.'''
    sent_model = SentenceTransformer('all-mpnet-base-v2')
    ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
    print('loaded sentence transformer')
    return sent_model, ce_model


class TextVectorizer:
    '''
    Uses the sentence-transformers model to extract sentence embeddings.
    '''
    def vectorize(self, x):
        # sent_model is the module-level SentenceTransformer loaded below
        sen_embeddings = sent_model.encode(x)
        return sen_embeddings


@st.cache_resource
def get_milvus_collection():
    '''Connect to the Milvus instance and load the pre-built headline collection.'''
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    print("Connected to DB")
    collection_name = os.environ.get("COLLECTION_NAME")
    collection = Collection(name=collection_name)
    collection.load()
    return collection


def find_similar_news(text: str, top_n: int = 5):
    '''Retrieve the top_n most similar headlines from Milvus, then re-rank them with the cross-encoder.'''
    search_params = {"metric_type": "L2"}  # Euclidean distance, matching the collection's index
    search_vec = vectorizer.vectorize(text)
    result = collection.search([search_vec],
                               anns_field='article_embed',  # vector field specified in the schema definition
                               param=search_params,
                               limit=top_n,
                               guarantee_timestamp=1,
                               output_fields=['article_desc'])  # which fields to return in the output
    output_dict = {"input_text": text,
                   "similar_texts": [hit.entity.get('article_desc') for hits in result for hit in hits]}
    # Re-rank the retrieved headlines by cross-encoder similarity to the input text
    texts = np.array(output_dict.get('similar_texts'))
    ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
    similarity_idxs = [*np.argsort(ce_similarity_scores)[::-1]]  # indices sorted by descending score
    texts = texts[similarity_idxs]
    ce_similarity_scores = ce_similarity_scores[similarity_idxs]
    # NOTE: the exact output formatting is an assumption (the original string was
    # truncated); this renders the re-ranked headlines as an HTML list for st.markdown.
    txt_similarity = [f'{i + 1}. {txt}' for i, txt in enumerate(texts)]
    return '<br>'.join(txt_similarity)


# Load the models and the Milvus collection once at startup; these module-level
# objects are referenced by TextVectorizer.vectorize and find_similar_news.
sent_model, ce_model = load_sentence_transformer()
vectorizer = TextVectorizer()
collection = get_milvus_collection()
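# The app assumes the collection was created beforehand with an 'article_embed'
# vector field and an 'article_desc' text field. The commented-out sketch below
# shows one plausible way that schema could have been defined; the field names
# match those used in find_similar_news, and dim=768 is the output size of
# all-mpnet-base-v2, but the actual schema of the pre-built collection may differ.
#
# def create_collection(collection_name: str) -> Collection:
#     fields = [
#         FieldSchema(name="article_id", dtype=DataType.INT64, is_primary=True, auto_id=True),
#         FieldSchema(name="article_desc", dtype=DataType.VARCHAR, max_length=2048),
#         FieldSchema(name="article_embed", dtype=DataType.FLOAT_VECTOR, dim=768),
#     ]
#     schema = CollectionSchema(fields, description="news headline embeddings")
#     coll = Collection(name=collection_name, schema=schema)
#     # L2 index on the vector field, matching the metric used in find_similar_news
#     coll.create_index(field_name="article_embed",
#                       index_params={"metric_type": "L2", "index_type": "AUTOINDEX", "params": {}})
#     return coll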
def main():
    desc = '''
    Embeddings of 300,000 news headlines are stored in a Milvus vector database, which is used as a
    feature store. The embedding of the input headline is computed using sentence-transformers
    (all-mpnet-base-v2). Similar news headlines are retrieved from the vector database using Euclidean (L2)
    distance as the similarity metric and are then re-ranked with a cross-encoder. Of the approaches
    evaluated, all-mpnet-base-v2 performed best, ahead of multi-qa-distilbert-cos-v1 fine-tuned with TSDAE
    and of embeddings extracted from a fine-tuned DistilBERT classifier.
    '''
    st.markdown(desc, unsafe_allow_html=True)
    news_txt = st.text_area("Paste the headline of a news article", "", height=30)
    top_n = st.slider('Select the number of similar articles to display', 1, 100, 10)
    if st.button("Submit"):
        result = find_similar_news(news_txt, top_n)
        st.markdown(result, unsafe_allow_html=True)


if __name__ == "__main__":
    main()
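# To run the app locally, set the connection details read by get_milvus_collection
# and launch Streamlit (the filename app.py is a placeholder; substitute the name
# this script is saved under):
#
#   export URI="<milvus endpoint>"
#   export TOKEN="<api token>"
#   export COLLECTION_NAME="<collection name>"
#   streamlit run app.py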