YchKhan commited on
Commit
c6cfdf4
·
verified ·
1 Parent(s): 6c698a2

Update classification.py

Browse files
Files changed (1) hide show
  1. classification.py +4 -3
classification.py CHANGED
@@ -173,12 +173,12 @@ def process_categories(categories, model):
173
  def match_categories(df, category_df, treshold=0.45):
174
  for topic in category_df['topic']:
175
  df[topic] = 0
176
- for i, ebd_content in enumerate(df['Embeddings']):
177
  if isinstance(ebd_content, torch.Tensor):
178
  cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
179
  high_score_indices = [i for i, score in enumerate(cos_scores) if score > treshold]
180
  for j in high_score_indices:
181
- df.loc[i, category_df.loc[j, 'topic']] = float(cos_scores[j])
182
  return df
183
 
184
  def save_data(df, filename):
@@ -193,9 +193,10 @@ def classification(column, file_path, categories, treshold):
193
 
194
  # Initialize models
195
  model_ST = initialize_models()
196
-
197
  # Generate embeddings for df
198
  df = generate_embeddings(df, model_ST, column)
 
199
 
200
 
201
  category_df = process_categories(categories, model_ST)
 
173
  def match_categories(df, category_df, treshold=0.45):
174
  for topic in category_df['topic']:
175
  df[topic] = 0
176
+ for index, ebd_content in enumerate(df['Embeddings']):
177
  if isinstance(ebd_content, torch.Tensor):
178
  cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
179
  high_score_indices = [i for i, score in enumerate(cos_scores) if score > treshold]
180
  for j in high_score_indices:
181
+ df.loc[index, category_df.loc[j, 'topic']] = 'float(cos_scores[j])'
182
  return df
183
 
184
  def save_data(df, filename):
 
193
 
194
  # Initialize models
195
  model_ST = initialize_models()
196
+ print('Generating Embeddings')
197
  # Generate embeddings for df
198
  df = generate_embeddings(df, model_ST, column)
199
+ print('Embeddings Generated')
200
 
201
 
202
  category_df = process_categories(categories, model_ST)