|
from sentence_transformers import SentenceTransformer |
|
from sentence_transformers.cross_encoder import CrossEncoder |
|
import os |
|
import numpy as np |
|
from datetime import datetime |
|
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema |
|
import streamlit as st |
|
import logging |
|
|
|
|
|
# Layout of every log record: timestamp, a space, then the message text.
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)

# Named logger shared by the functions in this module.
logger = logging.getLogger(name='hf_logger')
|
|
|
|
|
@st.cache_resource
def load_sentence_transformer():
    """Build the embedding model and the reranking model.

    Decorated with ``st.cache_resource`` so Streamlit reuses the returned
    objects across script reruns instead of reloading the weights.

    Returns:
        tuple: ``(sent_model, ce_model)`` where ``sent_model`` is the
        ``all-mpnet-base-v2`` SentenceTransformer and ``ce_model`` is the
        ``cross-encoder/stsb-distilroberta-base`` CrossEncoder.
    """
    logger.warning('Entering load_sentence_transformer')
    # Tuple elements are evaluated left to right, so the bi-encoder is still
    # constructed before the cross-encoder.
    models = (
        SentenceTransformer('all-mpnet-base-v2'),
        CrossEncoder('cross-encoder/stsb-distilroberta-base'),
    )
    logger.warning('Exiting load_sentence_transformer')
    return models
|
|
|
|
|
class TextVectorizer:
    """Turn text into sentence embeddings using a SentenceTransformer model.

    The model may be supplied to the constructor. When it is omitted (the
    historical ``TextVectorizer()`` usage), the module-level ``sent_model``
    is looked up at call time instead, so existing callers behave as before.
    """

    def __init__(self, model=None):
        # Optional explicit encoder; None means "use the module-level sent_model".
        self.model = model

    def vectorize_(self, x):
        """Return embedding(s) for ``x`` normalised to unit length.

        Unit-length vectors let an inner-product search act as cosine similarity.

        Args:
            x: a string or a list of strings, passed straight to ``model.encode``.

        Returns:
            The value of ``model.encode(x, normalize_embeddings=True)``.
        """
        logger.warning('Entering vectorize_()')

        # getattr keeps instances created without __init__ working as well.
        model = getattr(self, 'model', None)
        if model is None:
            # Fallback: the global bound by load_sentence_transformer() at import time.
            model = sent_model

        sent_embeddings = model.encode(x, normalize_embeddings=True)

        logger.warning('Exiting vectorize_()')
        return sent_embeddings
|
|
|
|
|
@st.cache_resource
def get_milvus_collection():
    """Connect to Milvus and return a handle to the configured collection.

    Settings are read from environment variables:
      * ``URI`` and ``TOKEN``: passed to ``connections.connect`` under the
        ``"default"`` alias.
      * ``COLLECTION_NAME``: name of the collection to open (required).

    Decorated with ``st.cache_resource`` so Streamlit reuses the returned
    collection across script reruns.

    Returns:
        pymilvus.Collection: the opened collection.

    Raises:
        RuntimeError: if ``COLLECTION_NAME`` is unset or empty.
    """
    logger.warning('Entering get_milvus_collection()')

    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    collection_name = os.environ.get("COLLECTION_NAME")

    if not collection_name:
        # Fail fast with an actionable message instead of letting
        # Collection(name=None) fail deep inside pymilvus.
        raise RuntimeError("COLLECTION_NAME environment variable is not set")

    connections.connect("default", uri=uri, token=token)
    collection = Collection(name=collection_name)

    # Route through the configured logger (timestamped) rather than print().
    logger.warning('Loaded collection %s', collection_name)
    logger.warning('Exiting get_milvus_collection()')
    return collection
|
|
|
def _render_hit_html(rank, hit, ce_score):
    """Return the HTML snippet for one search hit, with all metadata escaped.

    Args:
        rank: 1-based display position.
        hit: a Milvus search hit exposing ``.entity.get(field)`` and ``.distance``.
        ce_score: cross-encoder similarity score for this hit.
    """
    import html

    entity = hit.entity
    # Article fields are external data, and the caller renders this markup with
    # unsafe_allow_html=True. Escape everything so titles containing "<" or "&"
    # display literally and cannot inject markup.
    title = html.escape(str(entity.get('article_title')))
    date = html.escape(str(entity.get('article_date')))
    src = html.escape(str(entity.get('article_src')))

    url = str(entity.get('article_url') or '').strip()
    # Only link real web URLs; this blocks e.g. "javascript:" hrefs.
    if url.lower().startswith(('http://', 'https://')):
        href = html.escape(url, quote=True)
    else:
        href = '#'

    cross_encoder_similarity = str(np.round(ce_score, 4))
    cosine_similarity = str(np.round(hit.distance, 4))

    return (
        f'<a style="font-weight: bold; font-size:18px; color: black;" href="{href}" target="_blank">{rank}. {title}</a><br>\n'
        f'<b>Date:</b> {date}   <b>Source:</b> {src}<br>\n'
        f'<b>Cross encoder similarity:</b> {cross_encoder_similarity}   <b>Cosine similarity:</b> {cosine_similarity}\n'
        '<br><br>\n'
    )


def find_similar_news(text: str, collection, vectorizer, sent_model, ce_model, top_n: int = 100):
    """Find stored headlines similar to ``text`` and render them as HTML.

    Steps:
      1. Embed ``text`` with ``vectorizer``.
      2. Fetch the ``top_n`` nearest articles from Milvus, using inner product
         on the ``article_embed`` field.
      3. Rerank those hits by ``ce_model`` score, highest first.
      4. Format each hit as an HTML entry.

    Args:
        text: the query headline.
        collection: Milvus collection to search.
        vectorizer: object whose ``vectorize_(text)`` returns the query embedding.
        sent_model: not used by this function; kept for call compatibility.
        ce_model: cross-encoder whose ``predict(pairs)`` scores
            ``[query, title]`` pairs.
        top_n: number of hits requested from Milvus; all of them are displayed.

    Returns:
        str: markup wrapped in ``<html>...</html>``. It is ``'<html></html>'``
        when Milvus returns no hits.
    """
    logger.warning('Entering find_similar_news')

    search_params = {"metric_type": "IP"}
    search_vec = vectorizer.vectorize_(text)

    logger.warning('Querying Milvus for most similar results')
    results = collection.search(
        [search_vec],
        anns_field='article_embed',
        param=search_params,
        limit=top_n,
        guarantee_timestamp=1,
        output_fields=['article_title', 'article_src', 'article_url', 'article_date'],
    )[0]
    logger.warning('retrieved search results from Milvus')

    hits = list(results)
    if not hits:
        # Nothing to rerank; avoid sending an empty batch to the cross-encoder.
        logger.warning('Exiting find_similar_news')
        return '<html></html>'

    logger.warning('Computing cross encoder similarity scores')
    texts = [hit.entity.get('article_title') for hit in hits]
    ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
    # Indices ordered by descending cross-encoder score.
    similarity_idxs = np.argsort(ce_similarity_scores)[::-1]
    logger.warning('Retrieved cross encoder similarity scores')

    logger.warning('Generating HTML output')
    parts = ['<html>']
    parts.extend(
        _render_hit_html(rank, hits[i], ce_similarity_scores[i])
        for rank, i in enumerate(similarity_idxs, start=1)
    )
    parts.append('</html>')
    html_output = ''.join(parts)
    logger.warning('Successfully generated HTML output')

    logger.warning('Exiting find_similar_news')
    return html_output
|
|
|
|
|
vectorizer = TextVectorizer()
collection = get_milvus_collection()
sent_model, ce_model = load_sentence_transformer()

try:
    logger.warning('Entering the application')

    st.markdown("<h3>Find Recent Similar News</h3>", unsafe_allow_html=True)

    desc = '''<p style="font-size: 13px;">
Embeddings of news headlines are stored in Milvus vector database, used as a feature store.
The database is updated in realtime with new headlines using a CRON job.
Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2).
Similar news headlines are retrieved from the vector database using Inner Product as similarity metric and are reranked using cross encoder.
The embeddings are converted into unit vectors so that inner product can be used as cosine similarity, since Milvus doesn't support cosine similarity.
</p>
'''
    st.markdown(desc, unsafe_allow_html=True)

    # Streamlit rejects st.text_area heights below 68 px (StreamlitAPIException
    # in recent releases), so 68 is the smallest portable value.
    news_txt = st.text_area("Paste the headline of a news article", "", height=68)
    top_n = st.slider('Select the number of similar articles to display', 1, 10, 5)

    if st.button("Submit"):
        if not news_txt.strip():
            # Searching with an empty query would just return arbitrary articles.
            st.warning('Please paste a headline before submitting.')
        else:
            result = find_similar_news(news_txt, collection, vectorizer, sent_model, ce_model, top_n)
            st.markdown(result, unsafe_allow_html=True)

    logger.warning('Exiting the application')

except Exception as e:
    # Keep the full traceback in the server log; show a short message in the UI.
    logger.exception('Unhandled error in the application')
    st.error(f'An unexpected error occurred: \n{e}')
|
|