asoria HF staff commited on
Commit
64136bc
·
1 Parent(s): 5f46fb3

Changing sentence transformer

Browse files
Files changed (2) hide show
  1. app.py +54 -29
  2. requirements.txt +3 -1
app.py CHANGED
@@ -6,7 +6,11 @@ from bertopic import BERTopic
6
  import pandas as pd
7
  import gradio as gr
8
  from bertopic.representation import KeyBERTInspired
9
- import spaces
 
 
 
 
10
 
11
  logging.basicConfig(
12
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -24,7 +28,7 @@ def get_parquet_urls(dataset, config, split):
24
  if "error" in parquet_files:
25
  raise Exception(f"Error fetching parquet files: {parquet_files['error']}")
26
  parquet_urls = [file["url"] for file in parquet_files["parquet_files"]]
27
- logging.info(f"Parquet files: {parquet_urls}")
28
  return ",".join(f"'{url}'" for url in parquet_urls)
29
 
30
 
@@ -34,7 +38,7 @@ def get_docs_from_parquet(parquet_urls, column, offset, limit):
34
  logging.debug(f"Dataframe: {df.head(5)}")
35
  return df[column].tolist()
36
 
37
- @spaces.GPU
38
  def generate_topics(dataset, config, split, column, nested_column):
39
  logging.info(
40
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
@@ -45,39 +49,60 @@ def generate_topics(dataset, config, split, column, nested_column):
45
  chunk_size = 300
46
  offset = 0
47
  representation_model = KeyBERTInspired()
48
-
49
- docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
50
-
51
- base_model = BERTopic(
52
- representation_model=representation_model, min_topic_size=15
53
- ).fit(docs)
54
-
55
- yield base_model.get_topic_info(), base_model.visualize_topics()
56
-
 
 
 
 
57
  while True:
 
 
 
 
58
  offset = offset + chunk_size
59
  if not docs or offset >= limit:
60
  break
61
 
62
- docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
63
- logging.info(f"------------> New chunk data {offset=} {chunk_size=}")
64
- logging.info(docs[:5])
65
-
66
  new_model = BERTopic(
67
- "english", representation_model=representation_model, min_topic_size=15
68
- ).fit(docs)
69
- updated_model = BERTopic.merge_models([base_model, new_model])
70
- nr_new_topics = len(set(updated_model.topics_)) - len(set(base_model.topics_))
71
- new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
72
- logging.info("The following topics are newly found:")
73
- logging.info(f"{new_topics}\n")
74
-
75
- # Update the base model
76
- base_model = updated_model
77
-
 
 
 
 
 
 
 
 
78
  logging.info(base_model.get_topic_info())
79
- yield base_model.get_topic_info(), base_model.visualize_topics()
80
-
 
 
 
 
 
 
 
 
 
81
  return base_model.get_topic_info(), base_model.visualize_topics()
82
 
83
 
 
6
  import pandas as pd
7
  import gradio as gr
8
  from bertopic.representation import KeyBERTInspired
9
+ from umap import UMAP
10
+
11
+ # from cuml.cluster import HDBSCAN
12
+ # from cuml.manifold import UMAP
13
+ from sentence_transformers import SentenceTransformer
14
 
15
  logging.basicConfig(
16
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 
28
  if "error" in parquet_files:
29
  raise Exception(f"Error fetching parquet files: {parquet_files['error']}")
30
  parquet_urls = [file["url"] for file in parquet_files["parquet_files"]]
31
+ logging.debug(f"Parquet files: {parquet_urls}")
32
  return ",".join(f"'{url}'" for url in parquet_urls)
33
 
34
 
 
38
  logging.debug(f"Dataframe: {df.head(5)}")
39
  return df[column].tolist()
40
 
41
+
42
  def generate_topics(dataset, config, split, column, nested_column):
43
  logging.info(
44
  f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
 
49
  chunk_size = 300
50
  offset = 0
51
  representation_model = KeyBERTInspired()
52
+ base_model = None
53
+ # docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
54
+
55
+ # base_model = BERTopic(
56
+ # "english", representation_model=representation_model, min_topic_size=15
57
+ # )
58
+ # base_model.fit_transform(docs)
59
+
60
+ # yield base_model.get_topic_info(), base_model.visualize_topics()
61
+ # Create instances of GPU-accelerated UMAP and HDBSCAN
62
+ # umap_model = UMAP(n_components=5, n_neighbors=15, min_dist=0.0)
63
+ # hdbscan_model = HDBSCAN(min_samples=10, gen_min_span_tree=True)
64
+ sentence_model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
65
  while True:
66
+ docs = get_docs_from_parquet(parquet_urls, column, offset, chunk_size)
67
+ logging.info(f"------------> New chunk data {offset=} {chunk_size=}")
68
+ embeddings = sentence_model.encode(docs, show_progress_bar=True, batch_size=100)
69
+ logging.info(f"Embeddings shape: {embeddings.shape}")
70
  offset = offset + chunk_size
71
  if not docs or offset >= limit:
72
  break
73
 
 
 
 
 
74
  new_model = BERTopic(
75
+ "english",
76
+ embedding_model=sentence_model,
77
+ representation_model=representation_model,
78
+ min_topic_size=15, # umap_model=umap_model, hdbscan_model=hdbscan_model
79
+ )
80
+ logging.info("Fitting new model")
81
+ new_model.fit(docs, embeddings)
82
+ logging.info("End fitting new model")
83
+ if base_model is not None:
84
+ updated_model = BERTopic.merge_models([base_model, new_model])
85
+ nr_new_topics = len(set(updated_model.topics_)) - len(
86
+ set(base_model.topics_)
87
+ )
88
+ new_topics = list(updated_model.topic_labels_.values())[-nr_new_topics:]
89
+ logging.info("The following topics are newly found:")
90
+ logging.info(f"{new_topics}\n")
91
+ base_model = updated_model
92
+ else:
93
+ base_model = new_model
94
  logging.info(base_model.get_topic_info())
95
+ reduced_embeddings = UMAP(
96
+ n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine"
97
+ ).fit_transform(embeddings)
98
+ logging.info(f"Reduced embeddings shape: {reduced_embeddings.shape}")
99
+ yield (
100
+ base_model.get_topic_info(),
101
+ new_model.visualize_documents(
102
+ docs, embeddings=embeddings
103
+ ), # TODO: Visualize the merged models
104
+ )
105
+ logging.info("Finished processing all data")
106
  return base_model.get_topic_info(), base_model.visualize_topics()
107
 
108
 
requirements.txt CHANGED
@@ -4,4 +4,6 @@ umap-learn
4
  sentence-transformers
5
  datamapplot
6
  bertopic
7
- pandas
 
 
 
4
  sentence-transformers
5
  datamapplot
6
  bertopic
7
+ pandas
8
+ torch
9
+ cuml-cu11