# Polo123's picture
# Update logic.py
# 26ae746 verified
from tabula import read_pdf
from bs4 import BeautifulSoup
import requests
from llama_cpp import Llama
from bertopic.representation import KeyBERTInspired, LlamaCPP
from sentence_transformers import SentenceTransformer
from umap import UMAP
from hdbscan import HDBSCAN
from bertopic import BERTopic
import PIL
import numpy as np
import datamapplot
import re
def get_links(pdf_path="Artificial_Intelligence_Bookmarks_AwesomeList.pdf"):
    """Extract the bookmark URLs from every table in a PDF.

    The PDF is read with ``tabula.read_pdf(..., pages="all")``, which yields
    one DataFrame per detected table. On the first table the URLs are read
    from the ``'Unnamed: 2'`` column (presumably tabula found no header for
    that column there — TODO confirm against the PDF); on every later table
    they are read from the ``'Url'`` column.

    Args:
        pdf_path: Path of the bookmarks PDF. Defaults to the file name the
            script originally expected in the working directory.

    Returns:
        list: The URL cell values in table order, exactly as pandas parsed
        them (empty cells may therefore show up as NaN). Empty if the PDF
        contains no tables.
    """
    dfs = read_pdf(pdf_path, pages="all")
    if not dfs:
        return []
    links = dfs[0]['Unnamed: 2'].to_list()
    for df in dfs[1:]:
        links.extend(df['Url'].to_list())
    return links
#--------------------------------------
# text processing
def remove_tags(html):
    """Return the text of an HTML document as one space-separated string.

    ``<style>`` and ``<script>`` elements are deleted from the parse tree
    first, so stylesheet and JavaScript source do not end up in the text.
    The remaining text fragments come from BeautifulSoup's
    ``stripped_strings`` (leading/trailing whitespace removed) and are joined
    with single spaces.
    """
    soup = BeautifulSoup(html, "html.parser")
    # Remove non-content elements before collecting text.
    for element in soup.find_all(['style', 'script']):
        element.decompose()
    fragments = soup.stripped_strings
    return ' '.join(fragments)
# Compiled once at import time rather than on every remove_emoji() call
# (it runs once per crawled page). The character class is unchanged.
_EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # regional indicator letters (flags)
    "\U00002500-\U00002BEF"  # box drawing, shapes, misc symbols, dingbats,
                             # braille, math symbols, arrows (not Chinese)
    "\U00002702-\U000027B0"  # dingbats
    "\U000024C2-\U0001F251"  # very broad: also covers CJK, kana, Hangul,
                             # private use and full-width forms
    "\U0001f926-\U0001f937"
    "\U00010000-\U0010ffff"  # every supplementary-plane code point
    "\u2640-\u2642"
    "\u2600-\u2B55"
    "\u200d"                 # zero-width joiner
    "\u23cf"
    "\u23e9"
    "\u231a"
    "\ufe0f"                 # variation selector-16 (emoji presentation)
    "\u3030"
    "]+",
    re.UNICODE,
)


def remove_emoji(data):
    """Delete emoji and other pictographic/symbol characters from ``data``.

    The character class is deliberately broad. Besides emoji, it strips the
    symbol blocks starting at U+2500 and everything from U+24C2 to U+1F251,
    which includes CJK, kana and Hangul text. It also strips every code
    point above U+FFFF. Latin text and general punctuation (below U+24C2,
    apart from a few U+23xx/U+2500+ symbols and U+200D) are kept.

    Args:
        data: Text to clean.

    Returns:
        str: ``data`` with every matching run of characters removed.
    """
    return _EMOJI_PATTERN.sub('', data)
#-------------------------------------
def get_page(link, timeout=30):
    """Download ``link`` and return a cleaned prefix of its text.

    The page HTML is reduced to text with ``remove_tags`` and cut to its
    first 1050 characters. Emoji/symbols are then stripped with
    ``remove_emoji``, so the result can be shorter than 1050 characters.

    Args:
        link: URL to fetch.
        timeout: Seconds passed to ``requests.get``. Without a timeout, a
            stalled server could block the whole crawl indefinitely.

    Returns:
        str | None: The cleaned text, or ``None`` if fetching or cleaning
        failed. On failure the offending link is printed.
    """
    try:
        response = requests.get(link, timeout=timeout)
        clean_text = remove_tags(response.text)[:1050]
        return remove_emoji(clean_text)
    except Exception:
        # Best-effort crawl: report the bad link and let the caller skip it.
        # ``Exception`` rather than a bare except, so Ctrl-C / SystemExit
        # still interrupt the crawl.
        print(link)
        return None
def get_documents(links, min_chars=1000):
    """Fetch every link and keep only pages with enough text to model.

    Args:
        links: Iterable of URLs, each passed to ``get_page`` in order.
        min_chars: Minimum length a cleaned page must have to be kept. The
            default of 1000 matches the original ``len(text) > 999`` rule.
            Because ``get_page`` truncates at 1050 characters, the kept texts
            are then 1000-1050 characters long.

    Returns:
        list[str]: Cleaned page texts in the order of ``links``. Failed
        downloads (``None`` from ``get_page``) and short pages are dropped.
    """
    # One pass instead of repeatedly scanning the list for None (O(n^2)).
    pages = (get_page(link) for link in links)
    return [
        text for text in pages
        if text is not None and len(text) >= min_chars
    ]
#----------------------------------------
# Prompt for the LlamaCPP representation. [DOCUMENTS] and [KEYWORDS] are
# placeholders that BERTopic fills in for each topic.
LLM_LABEL_PROMPT = """ Q:
I have a topic that contains the following documents:
[DOCUMENTS]
The topic is described by the following keywords: '[KEYWORDS]'.
Based on the above information, can you give a short label of the topic of at most 5 words?
A:
"""


def get_topics(docs,
               model_path="openhermes-2.5-mistral-7b.Q4_K_M.gguf",
               embedding_model_name="BAAI/bge-small-en"):
    """Fit a BERTopic model on ``docs`` with KeyBERT and local-LLM labels.

    Pipeline: SentenceTransformer embeddings -> 5-d UMAP -> HDBSCAN clusters.
    Topic representations come from ``KeyBERTInspired`` (aspect
    ``"KeyBERT"``) and from a llama.cpp model prompted with
    ``LLM_LABEL_PROMPT`` (aspect ``"LLM"``).

    Args:
        docs: List of document strings, e.g. the output of get_documents().
        model_path: Path to the quantized GGUF model loaded by llama.cpp.
        embedding_model_name: SentenceTransformer model used to embed docs.

    Returns:
        BERTopic: The fitted model. It additionally carries
        ``reduced_embeddings_2d_``, a 2-d UMAP projection of the document
        embeddings with one row per document, for plotting.

    Raises:
        ValueError: If ``docs`` is empty. This is checked before any model
            is loaded.
    """
    if not docs:
        raise ValueError("get_topics() needs at least one document")

    # Quantized local LLM, used only to name topics.
    # NOTE(review): in llama-cpp-python `stop` is a generation-time option;
    # passed to the Llama constructor it is most likely swallowed by
    # **kwargs and ignored. Confirm, and move it to
    # LlamaCPP(pipeline_kwargs=...) if labels should stop at a newline.
    llm = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096, stop=["Q:", "\n"])
    representation_model = {
        "KeyBERT": KeyBERTInspired(),
        "LLM": LlamaCPP(llm, prompt=LLM_LABEL_PROMPT),
    }

    # Embed once; the same vectors feed both the fit and the 2-d projection.
    embedding_model = SentenceTransformer(embedding_model_name)
    embeddings = embedding_model.encode(docs, show_progress_bar=True)

    # 2-d projection used only for visualization; fixed seed for a
    # reproducible layout.
    reduced_embeddings = UMAP(
        n_neighbors=15, n_components=2, min_dist=0.0,
        metric='cosine', random_state=42,
    ).fit_transform(embeddings)

    # Sub-models used inside BERTopic for clustering.
    umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0,
                      metric='cosine', random_state=42)
    # min_cluster_size=2 allows very small topics. prediction_data=True makes
    # HDBSCAN keep the data its prediction/probability utilities need.
    hdbscan_model = HDBSCAN(min_cluster_size=2, metric='euclidean',
                            cluster_selection_method='eom', prediction_data=True)

    topic_model = BERTopic(
        # Sub-models
        embedding_model=embedding_model,
        umap_model=umap_model,
        hdbscan_model=hdbscan_model,
        representation_model=representation_model,
        # Hyperparameters
        top_n_words=10,
        verbose=True,
    )

    # Train with the precomputed embeddings. Per-document topics are then
    # available as topic_model.topics_.
    topic_model.fit_transform(docs, embeddings)

    # Keep the plotting projection with the model so callers that only
    # receive the model can still draw the document map.
    topic_model.reduced_embeddings_2d_ = reduced_embeddings
    return topic_model
#-------------------------------
# Visualize Topics
def _load_bertopic_logo(timeout=30):
    """Download the BERTopic logo as a NumPy array; return None on failure."""
    import io
    from PIL import Image  # `import PIL` alone does not load PIL.Image

    try:
        response = requests.get(
            "https://raw.githubusercontent.com/MaartenGr/BERTopic/master/images/logo.png",
            headers={'User-Agent': 'My User Agent 1.0'},
            timeout=timeout,
        )
        response.raise_for_status()
        return np.asarray(Image.open(io.BytesIO(response.content)))
    except (requests.RequestException, OSError) as err:
        # The logo is decorative: plot without it rather than fail.
        print(f"Could not load BERTopic logo: {err}")
        return None


def get_figure(topic_model, reduced_embeddings=None, topics=None):
    """Draw a datamapplot of the documents, labelled with LLM topic names.

    Args:
        topic_model: A fitted BERTopic model with an ``"LLM"`` representation
            aspect (as built by get_topics()).
        reduced_embeddings: 2-d coordinates, one row per training document.
            If omitted, the model's ``reduced_embeddings_2d_`` attribute is
            used when present.
        topics: Topic id per training document. Defaults to
            ``topic_model.topics_``.

    Returns:
        The figure returned by ``datamapplot.create_plot``.

    Raises:
        ValueError: If no 2-d embeddings were passed or found on the model.
    """
    if reduced_embeddings is None:
        reduced_embeddings = getattr(topic_model, "reduced_embeddings_2d_", None)
    if reduced_embeddings is None:
        raise ValueError(
            "get_figure() needs the 2-d document embeddings; "
            "pass reduced_embeddings=..."
        )
    if topics is None:
        topics = topic_model.topics_

    bertopic_logo = _load_bertopic_logo()

    # One cleaned label per topic, in the order of get_topics(full=True)["LLM"].
    # Each entry's first element is taken as (label, weight). Keep only the
    # first line of the label, drop double quotes, and collapse runs of
    # non-word characters into single spaces.
    llm_labels = []
    for representation in topic_model.get_topics(full=True)["LLM"].values():
        first_line = representation[0][0].split("\n")[0].replace('"', '')
        label = re.sub(r'\W+', ' ', first_line)
        llm_labels.append(label if label else "Unlabelled")

    # NOTE(review): `_outliers` is presumably 1 when the model has an outlier
    # topic (-1, listed first) and 0 otherwise. It offsets topic ids into
    # llm_labels.
    all_labels = [
        llm_labels[topic + topic_model._outliers] if topic != -1 else "Unlabelled"
        for topic in topics
    ]

    # NOTE(review): the title says "ArXiv" although the documents are AI
    # bookmarks; kept as-is.
    fig = datamapplot.create_plot(
        reduced_embeddings,
        all_labels,
        label_font_size=11,
        title="ArXiv - BERTopic",
        sub_title="Topics labeled with `openhermes-2.5-mistral-7b`",
        label_wrap_width=20,
        use_medoids=True,
        logo=bertopic_logo,
        logo_width=0.16,
    )
    return fig