File size: 5,469 Bytes
daf30d2 e97cba6 daf30d2 e97cba6 2c2dea5 daf30d2 2c2dea5 daf30d2 b5fd43f 2c2dea5 b5fd43f e97cba6 2c2dea5 e97cba6 b5fd43f 2c2dea5 daf30d2 2c2dea5 b5fd43f daf30d2 2c2dea5 daf30d2 2c2dea5 daf30d2 2c2dea5 daf30d2 2c2dea5 daf30d2 2c2dea5 e97cba6 2c2dea5 daf30d2 e97cba6 b5fd43f daf30d2 2c2dea5 b5d9b78 daf30d2 2c2dea5 daf30d2 a5a89f7 ebac715 2c2dea5 f043e94 2c2dea5 daf30d2 2c2dea5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import html
import logging
import os
from datetime import datetime

import numpy as np
import streamlit as st
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
# Prefix every log line with a timestamp; basicConfig attaches a stream
# handler to the root logger so messages reach the console.
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
# NOTE(review): the app logs routine progress via logger.warning() —
# presumably because basicConfig's default level (WARNING) would hide
# info-level messages; confirm before demoting any of those calls.
logger = logging.getLogger('hf_logger')
@st.cache_resource
def load_sentence_transformer():
    """Load and cache the embedding and re-ranking models.

    Returns a ``(bi_encoder, cross_encoder)`` pair: 'all-mpnet-base-v2'
    for sentence embeddings and 'cross-encoder/stsb-distilroberta-base'
    for re-ranking. ``@st.cache_resource`` makes this a one-time load
    per server process.
    """
    logger.warning('Entering load_sentence_transformer')
    bi_encoder = SentenceTransformer('all-mpnet-base-v2')
    cross_encoder = CrossEncoder('cross-encoder/stsb-distilroberta-base')
    logger.warning('Exiting load_sentence_transformer')
    return bi_encoder, cross_encoder
class TextVectorizer:
    '''
    Sentence-transformer wrapper that extracts sentence embeddings.
    Embeddings are normalized to unit length (see vectorize_), so inner
    product over them equals cosine similarity.
    '''

    def vectorize_(self, x, model=None):
        """Encode *x* into normalized sentence embeddings.

        Parameters
        ----------
        x : str or list[str]
            Text(s) to embed.
        model : optional
            Encoder exposing a SentenceTransformer-style ``encode``
            method. Defaults to the module-level ``sent_model`` global,
            preserving the original behavior; passing it explicitly
            removes the hidden global dependency and makes the class
            testable in isolation.

        Returns
        -------
        The embedding array produced by ``model.encode``.
        """
        logger.warning('Entering vectorize_()')
        encoder = sent_model if model is None else model
        sent_embeddings = encoder.encode(x, normalize_embeddings=True)
        logger.warning('Exiting vectorize_()')
        return sent_embeddings
@st.cache_resource
def get_milvus_collection():
    """Connect to Milvus and return the collection named by env vars.

    Reads ``URI``, ``TOKEN`` and ``COLLECTION_NAME`` from the
    environment. ``@st.cache_resource`` caches the result so the
    connection is established once per server process.
    """
    logger.warning('Entering get_milvus_collection()')
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    collection_name = os.environ.get("COLLECTION_NAME")
    collection = Collection(name=collection_name)
    # Fix: was `print(f"Loaded collection")` — a placeholder-less f-string
    # sent to stdout; route through the app logger and include the name.
    logger.warning('Loaded collection %s', collection_name)
    logger.warning('Exiting get_milvus_collection()')
    return collection
def find_similar_news(text: str, collection, vectorizer, sent_model, ce_model, top_n: int = 100):
    """Retrieve headlines similar to *text* and re-rank with a cross encoder.

    Parameters
    ----------
    text : str
        Query headline typed by the user (untrusted input).
    collection
        Milvus collection holding headline embeddings.
    vectorizer
        Object with a ``vectorize_(text)`` method returning the query vector.
    sent_model
        Unused here; the embedding comes from *vectorizer*. Kept so the
        call signature stays backward compatible.
    ce_model
        Cross encoder exposing ``predict(pairs)``; used for re-ranking.
    top_n : int
        Number of candidates to fetch from Milvus.

    Returns
    -------
    str
        HTML fragment listing the results, best cross-encoder match first.
    """
    logger.warning('Entering find_similar_news')
    # Embeddings are unit vectors, so IP (inner product) == cosine similarity.
    search_params = {"metric_type": "IP"}
    search_vec = vectorizer.vectorize_(text)
    logger.warning('Querying Milvus for most similar results')
    results = collection.search(
        [search_vec],
        anns_field='article_embed',  # vector field specified in the schema definition
        param=search_params,
        limit=top_n,
        guarantee_timestamp=1,
        output_fields=['article_title', 'article_src', 'article_url', 'article_date'],
    )[0]
    logger.warning('retrieved search results from Milvus')
    logger.warning('Computing cross encoder similarity scores')
    titles = [result.entity.get('article_title') for result in results]
    ce_similarity_scores = np.array(ce_model.predict([[text, title] for title in titles]))
    # Indices of results sorted by cross-encoder score, best first.
    similarity_idxs = np.argsort(ce_similarity_scores)[::-1]
    logger.warning('Retrieved cross encoder similarity scores')
    logger.warning('Generating HTML output')
    # Build the page as a list of fragments and join once — avoids
    # quadratic string concatenation in the loop.
    parts = ['<html>']
    for n, i in enumerate(similarity_idxs):
        entity = results[i].entity
        # Escape DB-sourced strings before interpolating into HTML; the
        # caller renders this with unsafe_allow_html=True.
        title_ = html.escape(str(entity.get('article_title')))
        date_ = html.escape(str(entity.get('article_date')))
        src_ = html.escape(str(entity.get('article_src')))
        url_ = html.escape(str(entity.get('article_url')), quote=True)
        cross_encoder_similarity = str(np.round(ce_similarity_scores[i], 4))
        cosine_similarity = str(np.round(results[i].distance, 4))
        parts.append(f'''<a style="font-weight: bold; font-size:18px; color: black;" href="{url_}" target="_blank">{n+1}. {title_}</a><br>
<b>Date:</b> {date_}   <b>Source:</b> {src_}<br>
<b>Cross encoder similarity:</b> {cross_encoder_similarity}   <b>Cosine similarity:</b> {cosine_similarity}
<br><br>
''')
    parts.append('</html>')
    html_output = ''.join(parts)
    logger.warning('Successfully generated HTML output')
    logger.warning('Exiting find_similar_news')
    return html_output
# Module-level singletons, created once at import time. get_milvus_collection()
# and load_sentence_transformer() are @st.cache_resource-cached, so Streamlit
# script re-runs reuse the same connection and models.
vectorizer = TextVectorizer()
collection = get_milvus_collection()
sent_model, ce_model = load_sentence_transformer()
# Streamlit page flow: render the form, and on Submit run the search and
# display the generated HTML. The try/except is the top-level error boundary.
try:
    logger.warning('Entering the application')
    st.markdown("<h3>Find Recent Similar News</h3>", unsafe_allow_html=True)
    # Static description of the retrieval pipeline (kept flush-left so the
    # rendered string content is unchanged).
    desc = '''<p style="font-size: 13px;">
Embeddings of news headlines are stored in Milvus vector database, used as a feature store.
The database is updated in realtime with new headlines using a CRON job.
Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2).
Similar news headlines are retrieved from the vector database using Inner Product as similarity metric and are reranked using cross encoder.
The embeddings are converted into unit vectors so that inner product can be used as cosine similarity, since Milvus doesn't support cosine similarity.
</p>
'''
    st.markdown(desc, unsafe_allow_html=True)
    news_txt = st.text_area("Paste the headline of a news article", "", height=30)
    top_n = st.slider('Select the number of similar articles to display', 1, 10, 5)
    if st.button("Submit"):
        result = find_similar_news(news_txt, collection, vectorizer, sent_model, ce_model, top_n)
        st.markdown(result, unsafe_allow_html=True)
    logger.warning('Exiting the application')
except Exception as e:
    # Top-level boundary: record the full traceback for debugging, then
    # surface a friendly message in the UI (original only showed the UI
    # message, losing the traceback).
    logger.exception('Unhandled application error')
    # Fix: typo "occured" -> "occurred" in the user-facing message.
    st.error(f'An unexpected error occurred: \n{e}')
|