"""Streamlit app that finds recent news headlines similar to a pasted one.

Headline embeddings are stored in a Milvus collection. A query headline is
embedded with a sentence transformer, the nearest articles are retrieved by
inner product and then re-ranked with a cross-encoder.
"""
import html
import logging
import os
from datetime import datetime

import numpy as np
import streamlit as st
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder

FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
logger = logging.getLogger('hf_logger')

# Only URLs with these schemes are rendered as clickable links; anything else
# (e.g. ``javascript:``) is left out so scraped data cannot inject script URLs.
_SAFE_URL_SCHEMES = ('http://', 'https://')
# Visual gap between two fields rendered on the same line.
_FIELD_GAP = '&nbsp;' * 4


@st.cache_resource
def load_sentence_transformer():
    """Load the embedding model and the re-ranking model.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded models
    across reruns instead of reloading them.

    Returns:
        Tuple ``(sent_model, ce_model)``: the ``all-mpnet-base-v2`` sentence
        transformer and the ``cross-encoder/stsb-distilroberta-base`` model.
    """
    logger.warning('Entering load_sentence_transformer')
    sent_model = SentenceTransformer('all-mpnet-base-v2')
    ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
    logger.warning('Exiting load_sentence_transformer')
    return sent_model, ce_model


class TextVectorizer:
    '''
    sentence transformers to extract sentence embeddings
    '''

    def __init__(self, model=None):
        """Create a vectorizer.

        Args:
            model: Sentence-transformer model used for encoding. When ``None``
                the module-level ``sent_model`` is looked up at call time,
                matching how the class behaved before this argument existed.
        """
        self.model = model

    def vectorize_(self, x):
        """Return unit-normalised embeddings for ``x``.

        Normalising the vectors makes Milvus inner-product search equivalent
        to cosine similarity.
        """
        logger.warning('Entering vectorize_()')
        model = self.model if self.model is not None else sent_model
        sent_embeddings = model.encode(x, normalize_embeddings=True)
        logger.warning('Exiting vectorize_()')
        return sent_embeddings


@st.cache_resource
def get_milvus_collection():
    """Connect to Milvus and return the collection named by ``COLLECTION_NAME``.

    Connection details come from the ``URI`` and ``TOKEN`` environment
    variables. Decorated with ``st.cache_resource`` so the connection is
    reused across reruns.

    Raises:
        RuntimeError: If the ``COLLECTION_NAME`` environment variable is unset.
    """
    logger.warning('Entering get_milvus_collection()')
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    collection_name = os.environ.get("COLLECTION_NAME")
    if not collection_name:
        # Fail with a clear message instead of an obscure pymilvus error.
        raise RuntimeError('COLLECTION_NAME environment variable is not set')
    connections.connect("default", uri=uri, token=token)
    collection = Collection(name=collection_name)
    logger.warning('Loaded collection %s', collection_name)
    logger.warning('Exiting get_milvus_collection()')
    return collection


def _escape(value):
    """Return ``value`` converted to ``str`` and HTML-escaped."""
    return html.escape(str(value))


def _render_hit(rank, hit, ce_score):
    """Return the HTML fragment describing one search hit at 1-based ``rank``.

    Every database field is escaped because the final output is rendered with
    ``unsafe_allow_html=True``.
    """
    entity = hit.entity
    title_ = _escape(entity.get('article_title'))
    date_ = _escape(entity.get('article_date'))
    src_ = _escape(entity.get('article_src'))
    url_ = str(entity.get('article_url') or '').strip()
    if url_.lower().startswith(_SAFE_URL_SCHEMES):
        title_ = (f'<a href="{html.escape(url_)}" target="_blank" '
                  f'rel="noopener noreferrer">{title_}</a>')
    cross_encoder_similarity = str(np.round(ce_score, 4))
    cosine_similarity = str(np.round(hit.distance, 4))
    return (
        f'{rank}. {title_}<br>'
        f'Date: {date_}{_FIELD_GAP}Source: {src_}<br>'
        f'Cross encoder similarity: {cross_encoder_similarity}{_FIELD_GAP}'
        f'Cosine similarity: {cosine_similarity}<br><br>'
    )


def find_similar_news(text: str, collection, vectorizer, sent_model, ce_model, top_n: int = 100):
    """Return an HTML snippet listing news articles similar to ``text``.

    The query is embedded with ``vectorizer``. The ``top_n`` nearest articles
    are fetched from ``collection`` by inner product on the ``article_embed``
    field, then re-ordered by descending cross-encoder score.

    Args:
        text: Headline to search for.
        collection: Milvus collection holding the article embeddings.
        vectorizer: Object whose ``vectorize_`` method embeds ``text``.
        sent_model: Unused; kept for backward compatibility (the vectorizer
            owns the embedding model).
        ce_model: Cross-encoder used to re-rank the retrieved titles.
        top_n: Number of articles to retrieve and display.

    Returns:
        HTML string with one entry per article. If Milvus returns no hits,
        the string contains only a short notice.
    """
    logger.warning('Entering find_similar_news')
    search_params = {"metric_type": "IP"}
    search_vec = vectorizer.vectorize_(text)

    logger.warning('Querying Milvus for most similar results')
    results = collection.search(
        [search_vec],
        anns_field='article_embed',  # vector field specified in the schema definition
        param=search_params,
        limit=top_n,
        guarantee_timestamp=1,
        output_fields=['article_title', 'article_src', 'article_url', 'article_date'],  # fields returned per hit
    )[0]
    logger.warning('retrieved search results from Milvus')

    if len(results) == 0:
        # Avoid calling the cross-encoder on an empty batch.
        logger.warning('No search results returned by Milvus')
        return '<div>No similar news articles found.</div>'

    logger.warning('Computing cross encoder similarity scores')
    texts = [hit.entity.get('article_title') for hit in results]
    ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
    # Indices of the hits ordered from highest to lowest cross-encoder score.
    similarity_idxs = np.argsort(ce_similarity_scores)[::-1]
    logger.warning('Retrieved cross encoder similarity scores')

    logger.warning('Generating HTML output')
    items = [
        _render_hit(rank, results[int(idx)], ce_similarity_scores[idx])
        for rank, idx in enumerate(similarity_idxs, start=1)
    ]
    html_output = '<div>' + ''.join(items) + '</div>'
    logger.warning('Successfully generated HTML output')
    logger.warning('Exiting find_similar_news')
    return html_output


# Load the models first so the vectorizer can be given its model explicitly.
sent_model, ce_model = load_sentence_transformer()
vectorizer = TextVectorizer(sent_model)
collection = get_milvus_collection()

try:
    logger.warning('Entering the application')
    st.markdown("\n\nFind Recent Similar News\n\n", unsafe_allow_html=True)
    desc = '''

Embeddings of news headlines are stored in Milvus vector database, used as a feature store. The database is updated in realtime with new headlines using a CRON job. Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2). Similar news headlines are retrieved from the vector database using Inner Product as similarity metric and are reranked using cross encoder. The embeddings are converted into unit vectors so that inner product can be used as cosine similarity, since Milvus doesn't support cosine similarity.

'''
    st.markdown(desc, unsafe_allow_html=True)
    # Streamlit rejects text_area heights below 68px, so 68 is the smallest valid height.
    news_txt = st.text_area("Paste the headline of a news article", "", height=68)
    top_n = st.slider('Select the number of similar articles to display', 1, 10, 5)
    if st.button("Submit"):
        query = news_txt.strip()
        if query:
            result = find_similar_news(query, collection, vectorizer, sent_model, ce_model, top_n)
            st.markdown(result, unsafe_allow_html=True)
        else:
            st.warning('Please paste a news headline before submitting.')
    logger.warning('Exiting the application')
except Exception as e:
    # Top-level UI boundary: keep the traceback in the logs, show a short message to the user.
    logger.exception('Unhandled error in the application')
    st.error(f'An unexpected error occurred: \n{e}')