# topic-clustering-global-dashboard/function/topic_clustering_social.py
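"""Cluster social-media posts into topics.

Pipeline, as implemented below: filter out very short posts, optionally drop
spam via an internal HTTP service, build a short "feature" text per post,
embed the features with SBERT, group them with agglomerative clustering on
cosine distance, and return the largest clusters.
"""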
import json
import time
from .utils import get_sbert_embedding, clean_text
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
from nltk import sent_tokenize  # requires NLTK "punkt" data: nltk.download("punkt")
import requests
# from clean_text import normalize_text
MAX_LENGTH_FEATURE = 250
MIN_LENGTH_FEATURE = 100
URL_CHECK_SPAM = "http://10.9.3.70:30036/predict"
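# URL_CHECK_SPAM points at an internal spam-classification service. Its contract is
# inferred from check_spam() below only: it accepts {"domain_id", "records"} and is
# assumed to return a JSON list aligned with "records", each item carrying a "label"
# field where 0 means "not spam".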
def check_spam(docs):
    """Drop documents that the external spam service labels as spam."""
    json_body = {
        "domain_id": "",
        "records": [
            {
                "text": doc.get("message", ""),
                "idxcol": 1
            } for doc in docs
        ]
    }
    result = requests.post(URL_CHECK_SPAM, json=json_body).json()
    # Keep only documents predicted as not spam (label == 0).
    docs = [x for i, x in enumerate(docs) if result[i]["label"] == 0]
    return docs
def process_feature(doc):
    """Build a short representative text ("feature") from a document's message.

    Cleaned paragraphs are concatenated until the feature reaches roughly
    MAX_LENGTH_FEATURE characters; if it is still shorter than
    MIN_LENGTH_FEATURE, the next paragraph is split into sentences and added
    sentence by sentence.
    """
    message = doc.get("message", "")
    paras = message.split("\n")
    feature = ""
    paras = [clean_text(x.strip(), normalize=False) for x in paras if x.strip() and len(x.strip()) > 10]
    for para in paras:
        if len(feature) + len(para) < MAX_LENGTH_FEATURE:
            feature += " " + para
        elif len(feature) < MIN_LENGTH_FEATURE:
            # Feature is still too short: add individual sentences from this paragraph.
            sens = sent_tokenize(para)
            for sen in sens:
                if len(feature) + len(sen) < MAX_LENGTH_FEATURE or len(feature.strip()) < MIN_LENGTH_FEATURE:
                    feature += " " + sen
    return feature
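# Hypothetical example of the feature builder above:
#   process_feature({"message": "Paragraph one about the topic.\nParagraph two with more detail."})
# returns the cleaned paragraphs joined with spaces, stopping once the feature
# text reaches roughly MAX_LENGTH_FEATURE characters.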
def topic_clustering(docs, distance_threshold, top_cluster=5, top_sentence=5, topn_summary=5,
                     sorted_field='', max_doc_per_cluster=50, delete_message=True, is_check_spam=True):
    """Group documents into topic clusters and return the top_cluster largest ones."""
    # global model, model_en
    # Keep only documents with a reasonably long message, capped at 30k documents.
    docs = [x for x in docs if len(x.get("message", "")) > 100]
    docs = docs[:30000]
    if is_check_spam:
        docs = check_spam(docs)
result = {}
cluster_score = {}
t1 = time.time()
if len(docs) < 1:
return result
elif len(docs) == 1:
return {
"0": docs
}
    # features = [doc.get('title', "") + ". " + doc.get('snippet', "") for doc in docs]
    # Build a feature text per document and drop documents whose feature is too short.
    f_docs = []
    for x in docs:
        ft = process_feature(x)
        if len(ft) > MIN_LENGTH_FEATURE:
            x["title"] = ft
            f_docs.append(x)
    docs = f_docs
    features = [x["title"] for x in docs]
# with open("feature", 'w') as f:
# json.dump(features, f, ensure_ascii = False)
# print(features)
    vectors = get_sbert_embedding(features)
    # Note: scikit-learn renamed `affinity` to `metric` (affinity was removed in 1.4);
    # `affinity='cosine'` here assumes an older sklearn release, switch to
    # `metric='cosine'` on newer ones.
    clusteror = AgglomerativeClustering(n_clusters=None, compute_full_tree=True, affinity='cosine',
                                        linkage='complete', distance_threshold=distance_threshold)
    clusteror.fit(vectors)
    print(f"Time encode + clustering: {time.time() - t1} {clusteror.n_clusters_}")
for i in range(clusteror.n_clusters_):
result[str(i + 1)] = []
cluster_score[str(i + 1)] = 0
for i in range(len(clusteror.labels_)):
cluster_no = clusteror.labels_[i]
if docs[i].get('domain','') not in ["cungcau.vn","baomoi.com","news.skydoor.net"]:
            # response_doc aliases docs[i] (the redundant empty-dict assignment was dropped).
            response_doc = docs[i]
            score = response_doc.get('score', 0)
            # Skip documents with an empty message.
            if not docs[i].get('message', '').strip():
                continue
if score > cluster_score[str(cluster_no + 1)]:
cluster_score[str(cluster_no + 1)] = score
if 'domain' in docs[i]:
response_doc['domain'] = docs[i]['domain']
if 'url' in docs[i]:
response_doc['url'] = docs[i]['url']
if 'title' in docs[i]:
response_doc['title'] = clean_text(docs[i]['title'])
if 'snippet' in docs[i]:
response_doc['snippet'] = clean_text(docs[i]['snippet'])
if 'created_time' in docs[i]:
response_doc['created_time'] = docs[i]['created_time']
if "sentiment" in docs[i]:
response_doc['sentiment'] = docs[i]['sentiment']
if 'message' in docs[i]:
title = docs[i].get('title','')
snippet = docs[i].get('snippet','')
message = docs[i].get('message','')
# if title.strip():
# split_mess = message.split(title)
# if len(split_mess) > 1:
# message = title.join(split_mess[1:])
# if snippet.strip():
# split_mess = message.split(snippet)
# if len(split_mess) > 1:
# message = snippet.join(split_mess[1:])
response_doc['message'] = clean_text(message)
if 'id' in docs[i]:
response_doc['id'] = docs[i]['id']
# response_doc['score'] = 0.0
# response_doc['title_summarize'] = []
# response_doc['content_summary'] = ""
# response_doc['total_facebook_viral'] = 0
result[str(cluster_no + 1)].append(response_doc)
    # Sort each cluster by message length (longest first), record per-cluster stats
    # on the first document, and drop empty clusters.
    empty_clus_ids = []
    for x in result:
        result[x] = sorted(result[x], key=lambda i: -len(i.get('message', '')))
        if len(result[x]) > 0:
            # if len(result[x]) > 1:
            #     result[x] = check_duplicate_title_domain(result[x])
            result[x][0]['num_docs'] = len(result[x])
            result[x][0]['max_score'] = cluster_score[x]
        else:
            empty_clus_ids.append(x)
    for x in empty_clus_ids:
        result.pop(x, None)
    # Keep only the top_cluster largest clusters.
    result = dict(sorted(result.items(), key=lambda i: -len(i[1]))[:top_cluster])
return result
# return post_processing(result, top_cluster=top_cluster, top_sentence=top_sentence, topn_summary=topn_summary, sorted_field = sorted_field, max_doc_per_cluster=max_doc_per_cluster, delete_message=delete_message)
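# Shape of the value returned by topic_clustering(), as built above:
#   {"<cluster_id>": [doc, doc, ...], ...}
# Documents in each cluster are sorted by message length (longest first), and the
# first document additionally carries "num_docs" and "max_score" for the cluster.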
if __name__ == '__main__':
    # with open("/home2/vietle/DA-Report/social.json", 'r') as f:
    #     docs = json.load(f)[:2000]
    with open("/home2/vietle/news-cms/topic_summarization/data/news_cms.social.json", 'r') as f:
        docs = json.load(f)[:10000]
    clusters = topic_clustering(docs, distance_threshold=0.2, top_cluster=5000, top_sentence=5,
                                topn_summary=5, sorted_field='', max_doc_per_cluster=50, delete_message=False)
    with open("/home2/vietle/news-cms/topic_summarization/cluster/news_cms.social.json", 'w') as f:
        json.dump(clusters, f, ensure_ascii=False)